perf: 优化 tree nodes 避免太慢 (#12472)

* perf: 优化 tree nodes 避免太慢

perf: 优化大量资产上的资产数生成比较慢

perf: 优化节点树

perf: 修改 tree nooooooooodes

perf: 优化一些 api 比较大的问题

perf: 优化平台 api

perf: 分页返回同步树

perf: 优化节点树

perf: 深度优化节点树

* perf: remove unused config

---------

Co-authored-by: ibuler <ibuler@qq.com>
pull/12481/head
fit2bot 2024-01-02 16:11:56 +08:00 committed by GitHub
parent e80a0e41ba
commit 2fcbfe9f21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 508 additions and 236 deletions

View File

@ -21,7 +21,6 @@ from common.drf.filters import BaseFilterSet, AttrRulesFilterBackend
from common.utils import get_logger, is_uuid
from orgs.mixins import generics
from orgs.mixins.api import OrgBulkModelViewSet
from ..mixin import NodeFilterMixin
from ...notifications import BulkUpdatePlatformSkipAssetUserMsg
logger = get_logger(__file__)
@ -86,7 +85,7 @@ class AssetFilterSet(BaseFilterSet):
return queryset.filter(protocols__name__in=value).distinct()
class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
"""
API endpoint that allows Asset to be viewed or edited.
"""
@ -114,9 +113,7 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
]
def get_queryset(self):
queryset = super().get_queryset() \
.prefetch_related('nodes', 'protocols') \
.select_related('platform', 'domain')
queryset = super().get_queryset()
if queryset.model is not Asset:
queryset = queryset.select_related('asset_ptr')
return queryset

View File

@ -20,14 +20,15 @@ class DomainViewSet(OrgBulkModelViewSet):
filterset_fields = ("name",)
search_fields = filterset_fields
ordering = ('name',)
serializer_classes = {
'default': serializers.DomainSerializer,
'list': serializers.DomainListSerializer,
}
def get_serializer_class(self):
if self.request.query_params.get('gateway'):
return serializers.DomainWithGatewaySerializer
return serializers.DomainSerializer
def get_queryset(self):
return super().get_queryset().prefetch_related('assets')
return super().get_serializer_class()
class GatewayViewSet(HostViewSet):

View File

@ -2,7 +2,7 @@ from typing import List
from rest_framework.request import Request
from assets.models import Node, Protocol
from assets.models import Node, Platform, Protocol
from assets.utils import get_node_from_request, is_query_node_all_assets
from common.utils import lazyproperty, timeit
@ -71,37 +71,43 @@ class SerializeToTreeNodeMixin:
return 'file'
@timeit
def serialize_assets(self, assets, node_key=None, pid=None):
if node_key is None:
get_pid = lambda asset: getattr(asset, 'parent_key', '')
else:
get_pid = lambda asset: node_key
def serialize_assets(self, assets, node_key=None, get_pid=None):
if not get_pid and not node_key:
get_pid = lambda asset, platform: getattr(asset, 'parent_key', '')
sftp_asset_ids = Protocol.objects.filter(name='sftp') \
.values_list('asset_id', flat=True)
sftp_asset_ids = list(sftp_asset_ids)
data = [
{
sftp_asset_ids = set(sftp_asset_ids)
platform_map = {p.id: p for p in Platform.objects.all()}
data = []
for asset in assets:
platform = platform_map.get(asset.platform_id)
if not platform:
continue
pid = node_key or get_pid(asset, platform)
if not pid or pid.isdigit():
continue
data.append({
'id': str(asset.id),
'name': asset.name,
'title': f'{asset.address}\n{asset.comment}',
'pId': pid or get_pid(asset),
'title': f'{asset.address}\n{asset.comment}'.strip(),
'pId': pid,
'isParent': False,
'open': False,
'iconSkin': self.get_icon(asset),
'iconSkin': self.get_icon(platform),
'chkDisabled': not asset.is_active,
'meta': {
'type': 'asset',
'data': {
'platform_type': asset.platform.type,
'platform_type': platform.type,
'org_name': asset.org_name,
'sftp': asset.id in sftp_asset_ids,
'name': asset.name,
'address': asset.address
},
}
}
for asset in assets
]
})
return data

View File

@ -29,7 +29,9 @@ class AssetPlatformViewSet(JMSModelViewSet):
}
def get_queryset(self):
queryset = super().get_queryset()
queryset = super().get_queryset().prefetch_related(
'protocols', 'automation'
)
queryset = queryset.filter(type__in=AllTypes.get_types_values())
return queryset

View File

@ -126,6 +126,8 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
include_assets = self.request.query_params.get('assets', '0') == '1'
if not self.instance or not include_assets:
return Asset.objects.none()
if self.instance.is_org_root():
return Asset.objects.none()
if query_all:
assets = self.instance.get_all_assets()
else:

View File

@ -268,7 +268,7 @@ class AllTypes(ChoicesMixin):
meta = {'type': 'category', 'category': category.value, '_type': category.value}
category_node = cls.choice_to_node(category, 'ROOT', meta=meta)
category_count = category_type_mapper.get(category, 0)
category_node['name'] += f'({category_count})'
category_node['name'] += f' ({category_count})'
nodes.append(category_node)
# Type 格式化
@ -277,7 +277,7 @@ class AllTypes(ChoicesMixin):
meta = {'type': 'type', 'category': category.value, '_type': tp.value}
tp_node = cls.choice_to_node(tp, category_node['id'], opened=False, meta=meta)
tp_count = category_type_mapper.get(category + '_' + tp, 0)
tp_node['name'] += f'({tp_count})'
tp_node['name'] += f' ({tp_count})'
platforms = tp_platforms.get(category + '_' + tp, [])
if not platforms:
tp_node['isParent'] = False
@ -286,7 +286,7 @@ class AllTypes(ChoicesMixin):
# Platform 格式化
for p in platforms:
platform_node = cls.platform_to_node(p, tp_node['id'], include_asset)
platform_node['name'] += f'({platform_count.get(p.id, 0)})'
platform_node['name'] += f' ({platform_count.get(p.id, 0)})'
nodes.append(platform_node)
return nodes

View File

@ -63,11 +63,10 @@ class NodeFilterBackend(filters.BaseFilterBackend):
query_all = is_query_node_all_assets(request)
if query_all:
return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') |
Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key)
).distinct()
else:
print("Query query origin: ", queryset.count())
return queryset.filter(nodes__key=node.key).distinct()

View File

@ -13,7 +13,7 @@ from django.db.transaction import atomic
from django.utils.translation import gettext_lazy as _, gettext
from common.db.models import output_as_string
from common.utils import get_logger
from common.utils import get_logger, timeit
from common.utils.lock import DistributedLock
from orgs.mixins.models import OrgManager, JMSOrgBaseModel
from orgs.models import Organization
@ -195,11 +195,6 @@ class FamilyMixin:
ancestor_keys = self.get_ancestor_keys(with_self=with_self)
return self.__class__.objects.filter(key__in=ancestor_keys)
# @property
# def parent_key(self):
# parent_key = ":".join(self.key.split(":")[:-1])
# return parent_key
def compute_parent_key(self):
return compute_parent_key(self.key)
@ -349,29 +344,26 @@ class NodeAllAssetsMappingMixin:
return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id)
@classmethod
@timeit
def generate_node_all_asset_ids_mapping(cls, org_id):
from .asset import Asset
logger.info(f'Generate node asset mapping: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
logger.info(f'Generate node asset mapping: org_id={org_id}')
t1 = time.time()
with tmp_to_org(org_id):
node_ids_key = Node.objects.annotate(
char_id=output_as_string('id')
).values_list('char_id', 'key')
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_asset_ids = Asset.nodes.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
node_id_ancestor_keys_mapping = {
node_id: cls.get_node_ancestor_keys(node_key, with_self=True)
for node_id, node_key in node_ids_key
}
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_asset_ids = cls.assets.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
nodeid_assetsid_mapping = defaultdict(set)
for node_id, asset_id in nodes_asset_ids:
nodeid_assetsid_mapping[node_id].add(asset_id)
@ -386,7 +378,7 @@ class NodeAllAssetsMappingMixin:
mapping[ancestor_key].update(asset_ids)
t3 = time.time()
logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2 - t1, t3 - t2))
logger.info('Generate asset nodes mapping, DB query: {:.2f}s, mapping: {:.2f}s'.format(t2 - t1, t3 - t2))
return mapping
@ -436,6 +428,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
return asset_ids
@classmethod
@timeit
def get_nodes_all_assets(cls, *nodes):
from .asset import Asset
node_ids = set()
@ -559,11 +552,6 @@ class Node(JMSOrgBaseModel, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
def __str__(self):
return self.full_value
# def __eq__(self, other):
# if not other:
# return False
# return self.id == other.id
#
def __gt__(self, other):
self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')]

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.db.models import Count
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
@ -7,18 +8,15 @@ from common.serializers import ResourceLabelsMixin
from common.serializers.fields import ObjectRelatedField
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .gateway import GatewayWithAccountSecretSerializer
from ..models import Domain, Asset
from ..models import Domain
__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer']
__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer', 'DomainListSerializer']
class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
gateways = ObjectRelatedField(
many=True, required=False, label=_('Gateway'), read_only=True,
)
assets = ObjectRelatedField(
many=True, required=False, queryset=Asset.objects, label=_('Asset')
)
class Meta:
model = Domain
@ -30,7 +28,9 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
def to_representation(self, instance):
data = super().to_representation(instance)
assets = data['assets']
assets = data.get('assets')
if assets is None:
return data
gateway_ids = [str(i['id']) for i in data['gateways']]
data['assets'] = [i for i in assets if str(i['id']) not in gateway_ids]
return data
@ -49,6 +49,20 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
return queryset
class DomainListSerializer(DomainSerializer):
assets_amount = serializers.IntegerField(label=_('Assets amount'), read_only=True)
class Meta(DomainSerializer.Meta):
fields = list(set(DomainSerializer.Meta.fields + ['assets_amount']) - {'assets'})
@classmethod
def setup_eager_loading(cls, queryset):
queryset = queryset.annotate(
assets_amount=Count('assets'),
)
return queryset
class DomainWithGatewaySerializer(serializers.ModelSerializer):
gateways = GatewayWithAccountSecretSerializer(many=True, read_only=True)

View File

@ -80,10 +80,11 @@ RELATED_NODE_IDS = '_related_node_ids'
@receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, using, **kwargs):
logger.debug("Asset pre delete signal recv: {}".format(instance))
node_ids = Node.objects.filter(assets=instance) \
.distinct().values_list('id', flat=True)
setattr(instance, RELATED_NODE_IDS, node_ids)
node_ids = list(node_ids)
logger.debug("Asset pre delete signal recv: {}, node_ids: {}".format(instance, node_ids))
setattr(instance, RELATED_NODE_IDS, list(node_ids))
m2m_changed.send(
sender=Asset.nodes.through, instance=instance,
reverse=False, model=Node, pk_set=node_ids,
@ -93,8 +94,8 @@ def on_asset_delete(instance: Asset, using, **kwargs):
@receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, using, **kwargs):
logger.debug("Asset post delete signal recv: {}".format(instance))
node_ids = getattr(instance, RELATED_NODE_IDS, [])
logger.debug("Asset post delete signal recv: {}, node_ids: {}".format(instance, node_ids))
if node_ids:
m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False,

View File

@ -15,8 +15,8 @@ from ..tasks import check_node_assets_amount_task
logger = get_logger(__file__)
@on_transaction_commit
@receiver(m2m_changed, sender=Asset.nodes.through)
@on_transaction_commit
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
@ -37,7 +37,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
update_nodes_assets_amount(node_ids=node_ids)
@merge_delay_run(ttl=5)
@merge_delay_run(ttl=30)
def update_nodes_assets_amount(node_ids=()):
nodes = Node.objects.filter(id__in=node_ids)
nodes = Node.get_ancestor_queryset(nodes)

View File

@ -21,7 +21,7 @@ logger = get_logger(__name__)
node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)()
@merge_delay_run(ttl=5)
@merge_delay_run(ttl=30)
def expire_node_assets_mapping(org_ids=()):
logger.debug("Recv asset nodes changed signal, expire memery node asset mapping")
# 所有进程清除(自己的 memory 数据)
@ -53,8 +53,9 @@ def on_node_post_delete(sender, instance, **kwargs):
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs):
expire_node_assets_mapping(org_ids=(instance.org_id,))
def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
if action.startswith('post'):
expire_node_assets_mapping(org_ids=(instance.org_id,))
@receiver(django_ready)

View File

@ -98,12 +98,17 @@ class QuerySetMixin:
return queryset
if self.action == 'metadata':
queryset = queryset.none()
if self.action in ['list', 'metadata']:
serializer_class = self.get_serializer_class()
if serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
queryset = serializer_class.setup_eager_loading(queryset)
return queryset
def paginate_queryset(self, queryset):
page = super().paginate_queryset(queryset)
serializer_class = self.get_serializer_class()
if page and serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
ids = [i.id for i in page]
page = self.get_queryset().filter(id__in=ids)
page = serializer_class.setup_eager_loading(page)
return page
class ExtraFilterFieldsMixin:
"""

View File

@ -65,7 +65,7 @@ class EventLoopThread(threading.Thread):
_loop_thread = EventLoopThread()
_loop_thread.setDaemon(True)
_loop_thread.daemon = True
_loop_thread.start()
executor = ThreadPoolExecutor(
max_workers=10,

View File

@ -62,7 +62,7 @@ def digest_sql_query():
method = current_request.method
path = current_request.get_full_path()
print(">>> [{}] {}".format(method, path))
print(">>>. [{}] {}".format(method, path))
for table_name, queries in table_queries.items():
if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
continue
@ -77,9 +77,9 @@ def digest_sql_query():
sql = query['sql']
if not sql or not sql.startswith('SELECT'):
continue
print('\t{}. {}'.format(i, sql))
print('\t{}.[{}s] {}'.format(i, round(float(query['time']), 2), sql[:1000]))
logger.debug(">>> [{}] {}".format(method, path))
# logger.debug(">>> [{}] {}".format(method, path))
for name, counter in counters:
logger.debug("Query {:3} times using {:.2f}s {}".format(
counter.counter, counter.time, name)

View File

@ -220,7 +220,7 @@ def timeit(func):
now = time.time()
result = func(*args, **kwargs)
using = (time.time() - now) * 1000
msg = "End call {}, using: {:.1f}ms".format(name, using)
msg = "Ends call: {}, using: {:.1f}ms".format(name, using)
logger.debug(msg)
return result

View File

@ -1,18 +1,16 @@
from functools import wraps
import threading
from functools import wraps
from django.db import transaction
from redis_lock import (
Lock as RedisLock, NotAcquired, UNLOCK_SCRIPT,
EXTEND_SCRIPT, RESET_SCRIPT, RESET_ALL_SCRIPT
)
from redis import Redis
from django.db import transaction
from common.utils import get_logger
from common.utils.inspect import copy_function_args
from common.utils.connection import get_redis_client
from jumpserver.const import CONFIG
from common.local import thread_local
from common.utils import get_logger
from common.utils.connection import get_redis_client
from common.utils.inspect import copy_function_args
logger = get_logger(__file__)
@ -76,6 +74,7 @@ class DistributedLock(RedisLock):
# 要创建一个新的锁对象
with self.__class__(**self.kwargs_copy):
return func(*args, **kwds)
return inner
@classmethod
@ -95,7 +94,6 @@ class DistributedLock(RedisLock):
if self.locked():
owner_id = self.get_owner_id()
local_owner_id = getattr(thread_local, self.name, None)
if local_owner_id and owner_id == local_owner_id:
return True
return False
@ -140,14 +138,16 @@ class DistributedLock(RedisLock):
logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
return
else:
self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
self._raise_exc_with_log(
f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
def _release_on_reentrant_locked_by_me(self):
logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name}')
id = getattr(thread_local, self.name, None)
if id != self.id:
raise PermissionError(f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
raise PermissionError(
f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
try:
# 这里要保证先删除 thread_local 的标记,
delattr(thread_local, self.name)
@ -191,7 +191,7 @@ class DistributedLock(RedisLock):
# 处理是否在事务提交时才释放锁
if self._release_on_transaction_commit:
logger.debug(
f'Release lock on transaction commit ... :lock_id={self.id} lock={self.name}')
f'Release lock on transaction commit:lock_id={self.id} lock={self.name}')
transaction.on_commit(_release)
else:
_release()

View File

@ -531,6 +531,7 @@ class Config(dict):
'SYSLOG_SOCKTYPE': 2,
'PERM_EXPIRED_CHECK_PERIODIC': 60 * 60,
'PERM_TREE_REGEN_INTERVAL': 1,
'FLOWER_URL': "127.0.0.1:5555",
'LANGUAGE_CODE': 'zh',
'TIME_ZONE': 'Asia/Shanghai',

View File

@ -208,6 +208,7 @@ OPERATE_LOG_ELASTICSEARCH_CONFIG = CONFIG.OPERATE_LOG_ELASTICSEARCH_CONFIG
MAX_LIMIT_PER_PAGE = CONFIG.MAX_LIMIT_PER_PAGE
DEFAULT_PAGE_SIZE = CONFIG.DEFAULT_PAGE_SIZE
PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
# Magnus DB Port
MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS

View File

@ -21,7 +21,7 @@ LOGGING = {
},
'main': {
'datefmt': '%Y-%m-%d %H:%M:%S',
'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s',
'format': '%(asctime)s [%(levelname).4s] %(message)s',
},
'exception': {
'datefmt': '%Y-%m-%d %H:%M:%S',

View File

@ -75,7 +75,7 @@ model_cache_field_mapper = {
class OrgResourceStatisticsRefreshUtil:
@staticmethod
@merge_delay_run(ttl=5)
@merge_delay_run(ttl=30)
def refresh_org_fields(org_fields=()):
for org, cache_field_name in org_fields:
OrgResourceStatisticsCache(org).expire(*cache_field_name)
@ -104,7 +104,7 @@ def on_post_delete_refresh_org_resource_statistics_cache(sender, instance, **kwa
def _refresh_session_org_resource_statistics_cache(instance: Session):
cache_field_name = [
'total_count_online_users', 'total_count_online_sessions',
'total_count_today_active_assets','total_count_today_failed_sessions'
'total_count_today_active_assets', 'total_count_today_failed_sessions'
]
org_cache = OrgResourceStatisticsCache(instance.org)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
#
from orgs.mixins.api import OrgBulkModelViewSet
from perms import serializers
from perms.filters import AssetPermissionFilter
@ -13,7 +14,10 @@ class AssetPermissionViewSet(OrgBulkModelViewSet):
资产授权列表的增删改查api
"""
model = AssetPermission
serializer_class = serializers.AssetPermissionSerializer
serializer_classes = {
'default': serializers.AssetPermissionSerializer,
'list': serializers.AssetPermissionListSerializer,
}
filterset_class = AssetPermissionFilter
search_fields = ('name',)
ordering = ('name',)

View File

@ -1,16 +1,14 @@
from django.conf import settings
from rest_framework.response import Response
from assets.models import Asset
from assets.api import SerializeToTreeNodeMixin
from assets.models import Asset
from common.utils import get_logger
from ..assets import UserAllPermedAssetsApi
from .mixin import RebuildTreeMixin
from ..assets import UserAllPermedAssetsApi
logger = get_logger(__name__)
__all__ = [
'UserAllPermedAssetsAsTreeApi',
'UserUngroupAssetsAsTreeApi',
@ -31,7 +29,7 @@ class AssetTreeMixin(RebuildTreeMixin, SerializeToTreeNodeMixin):
if request.query_params.get('search'):
""" 限制返回数量, 搜索的条件不精准时,会返回大量的无意义数据 """
assets = assets[:999]
data = self.serialize_assets(assets, None)
data = self.serialize_assets(assets, 'root')
return Response(data=data)
@ -42,6 +40,7 @@ class UserAllPermedAssetsAsTreeApi(AssetTreeMixin, UserAllPermedAssetsApi):
class UserUngroupAssetsAsTreeApi(UserAllPermedAssetsAsTreeApi):
""" 用户 '未分组节点的资产(直接授权的资产)' 作为树 """
def get_assets(self):
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return super().get_assets()

View File

@ -1,6 +1,4 @@
import abc
import re
from collections import defaultdict
from urllib.parse import parse_qsl
from django.conf import settings
@ -13,7 +11,6 @@ from rest_framework.response import Response
from accounts.const import AliasAccount
from assets.api import SerializeToTreeNodeMixin
from assets.const import AllTypes
from assets.models import Asset
from assets.utils import KubernetesTree
from authentication.models import ConnectionToken
@ -38,21 +35,36 @@ class BaseUserNodeWithAssetAsTreeApi(
SelfOrPKUserMixin, RebuildTreeMixin,
SerializeToTreeNodeMixin, ListAPIView
):
page_limit = 10000
def list(self, request, *args, **kwargs):
nodes, assets = self.get_nodes_assets()
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, node_key=self.node_key_for_serialize_assets)
data = list(tree_nodes) + list(tree_assets)
return Response(data=data)
offset = int(request.query_params.get('offset', 0))
page_assets = self.get_page_assets()
if not offset:
nodes, assets = self.get_nodes_assets()
page = page_assets[:self.page_limit]
assets = [*assets, *page]
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, **self.serialize_asset_kwargs)
data = list(tree_nodes) + list(tree_assets)
else:
page = page_assets[offset:(offset + self.page_limit)]
data = self.serialize_assets(page, **self.serialize_asset_kwargs) if page else []
offset += len(page)
headers = {'X-JMS-TREE-OFFSET': offset} if offset else {}
return Response(data=data, headers=headers)
@abc.abstractmethod
def get_nodes_assets(self):
return [], []
@lazyproperty
def node_key_for_serialize_assets(self):
return None
def get_page_assets(self):
return []
@property
def serialize_asset_kwargs(self):
return {}
class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
@ -61,7 +73,6 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
def get_nodes_assets(self):
self.query_node_util = UserPermNodeUtil(self.request.user)
self.query_asset_util = UserPermAssetUtil(self.request.user)
ung_nodes, ung_assets = self._get_nodes_assets_for_ungrouped()
fav_nodes, fav_assets = self._get_nodes_assets_for_favorite()
all_nodes, all_assets = self._get_nodes_assets_for_all()
@ -69,31 +80,37 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
assets = list(ung_assets) + list(fav_assets) + list(all_assets)
return nodes, assets
def get_page_assets(self):
return self.query_asset_util.get_all_assets().annotate(parent_key=F('nodes__key'))
@timeit
def _get_nodes_assets_for_ungrouped(self):
if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return [], []
node = self.query_node_util.get_ungrouped_node()
assets = self.query_asset_util.get_ungroup_assets()
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \
.prefetch_related('platform')
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField()))
return [node], assets
@lazyproperty
def query_asset_util(self):
return UserPermAssetUtil(self.user)
@timeit
def _get_nodes_assets_for_favorite(self):
node = self.query_node_util.get_favorite_node()
assets = self.query_asset_util.get_favorite_assets()
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \
.prefetch_related('platform')
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField()))
return [node], assets
@timeit
def _get_nodes_assets_for_all(self):
nodes = self.query_node_util.get_whole_tree_nodes(with_special=False)
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
assets = self.query_asset_util.get_perm_nodes_assets()
else:
assets = self.query_asset_util.get_all_assets()
assets = assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform')
assets = Asset.objects.none()
assets = assets.annotate(parent_key=F('nodes__key'))
return nodes, assets
@ -103,6 +120,7 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
# 默认展开的节点key
default_unfolded_node_key = None
@timeit
def get_nodes_assets(self):
query_node_util = UserPermNodeUtil(self.user)
query_asset_util = UserPermAssetUtil(self.user)
@ -136,14 +154,14 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
node_key = getattr(node, 'key', None)
return node_key
@lazyproperty
def node_key_for_serialize_assets(self):
return self.query_node_key or self.default_unfolded_node_key
@property
def serialize_asset_kwargs(self):
return {
'node_key': self.query_node_key or self.default_unfolded_node_key
}
class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(
SelfOrPKUserMixin, SerializeToTreeNodeMixin, ListAPIView
):
class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(BaseUserNodeWithAssetAsTreeApi):
@property
def is_sync(self):
sync = self.request.query_params.get('sync', 0)
@ -151,66 +169,52 @@ class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(
@property
def tp(self):
return self.request.query_params.get('type')
def get_assets(self):
query_asset_util = UserPermAssetUtil(self.user)
node = PermNode.objects.filter(
granted_node_rels__user=self.user, parent_key='').first()
if node:
__, assets = query_asset_util.get_node_all_assets(node.id)
else:
assets = Asset.objects.none()
return assets
def to_tree_nodes(self, assets):
if not assets:
return []
assets = assets.annotate(tp=F('platform__type'))
asset_type_map = defaultdict(list)
for asset in assets:
asset_type_map[asset.tp].append(asset)
tp = self.tp
if tp:
assets = asset_type_map.get(tp, [])
if not assets:
return []
pid = f'ROOT_{str(assets[0].category).upper()}_{tp}'
return self.serialize_assets(assets, pid=pid)
params = self.request.query_params
get_root = not list(filter(lambda x: params.get(x), ('type', 'n')))
resource_platforms = assets.order_by('id').values_list('platform_id', flat=True)
node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=get_root)
pattern = re.compile(r'\(0\)?')
nodes = []
for node in node_all:
meta = node.get('meta', {})
if pattern.search(node['name']) or meta.get('type') == 'platform':
continue
_type = meta.get('_type')
if _type:
node['type'] = _type
meta.setdefault('data', {})
node['meta'] = meta
nodes.append(node)
return [params.get('category'), params.get('type')]
if not self.is_sync:
return nodes
@lazyproperty
def query_asset_util(self):
return UserPermAssetUtil(self.user)
asset_nodes = []
for node in nodes:
node['open'] = True
tp = node.get('meta', {}).get('_type')
if not tp:
continue
assets = asset_type_map.get(tp, [])
asset_nodes += self.serialize_assets(assets, pid=node['id'])
return nodes + asset_nodes
@timeit
def get_assets(self):
return self.query_asset_util.get_all_assets()
def list(self, request, *args, **kwargs):
assets = self.get_assets()
nodes = self.to_tree_nodes(assets)
return Response(data=nodes)
def _get_tree_nodes_async(self):
if not self.tp or not all(self.tp):
nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user)
return nodes, []
category, tp = self.tp
assets = self.get_assets().filter(platform__type=tp, platform__category=category)
return [], assets
def _get_tree_nodes_sync(self):
if self.request.query_params.get('lv'):
return []
nodes = self.query_asset_util.get_type_nodes_tree()
return nodes, []
@property
def serialize_asset_kwargs(self):
return {
'get_pid': lambda asset, platform: 'ROOT_{}_{}'.format(platform.category.upper(), platform.type),
}
def serialize_nodes(self, nodes, with_asset_amount=False):
return nodes
def get_nodes_assets(self):
if self.is_sync:
return self._get_tree_nodes_sync()
else:
return self._get_tree_nodes_async()
def get_page_assets(self):
if self.is_sync:
return self.get_assets()
else:
return []
class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):

View File

@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _
from accounts.const import AliasAccount
from accounts.models import Account
from assets.models import Asset
from common.utils import date_expired_default
from common.utils import date_expired_default, lazyproperty
from common.utils.timezone import local_now
from labels.mixins import LabeledMixin
from orgs.mixins.models import JMSOrgBaseModel
@ -105,6 +105,22 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
return True
return False
@lazyproperty
def users_amount(self):
return self.users.count()
@lazyproperty
def user_groups_amount(self):
return self.user_groups.count()
@lazyproperty
def assets_amount(self):
return self.assets.count()
@lazyproperty
def nodes_amount(self):
return self.nodes.count()
def get_all_users(self):
from users.models import User
user_ids = self.users.all().values_list('id', flat=True)
@ -143,11 +159,14 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
@classmethod
def get_all_users_for_perms(cls, perm_ids, flat=False):
user_ids = cls.users.through.objects.filter(assetpermission_id__in=perm_ids) \
user_ids = cls.users.through.objects \
.filter(assetpermission_id__in=perm_ids) \
.values_list('user_id', flat=True).distinct()
group_ids = cls.user_groups.through.objects.filter(assetpermission_id__in=perm_ids) \
group_ids = cls.user_groups.through.objects \
.filter(assetpermission_id__in=perm_ids) \
.values_list('usergroup_id', flat=True).distinct()
group_user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \
group_user_ids = User.groups.through.objects \
.filter(usergroup_id__in=group_ids) \
.values_list('user_id', flat=True).distinct()
user_ids = set(user_ids) | set(group_user_ids)
if flat:

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.db.models import Q
from django.db.models import Q, Count
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
@ -14,7 +14,7 @@ from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from perms.models import ActionChoices, AssetPermission
from users.models import User, UserGroup
__all__ = ["AssetPermissionSerializer", "ActionChoicesField"]
__all__ = ["AssetPermissionSerializer", "ActionChoicesField", "AssetPermissionListSerializer"]
class ActionChoicesField(BitChoicesField):
@ -142,8 +142,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
def perform_display_create(instance, **kwargs):
# 用户
users_to_set = User.objects.filter(
Q(name__in=kwargs.get("users_display"))
| Q(username__in=kwargs.get("users_display"))
Q(name__in=kwargs.get("users_display")) |
Q(username__in=kwargs.get("users_display"))
).distinct()
instance.users.add(*users_to_set)
# 用户组
@ -153,8 +153,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
instance.user_groups.add(*user_groups_to_set)
# 资产
assets_to_set = Asset.objects.filter(
Q(address__in=kwargs.get("assets_display"))
| Q(name__in=kwargs.get("assets_display"))
Q(address__in=kwargs.get("assets_display")) |
Q(name__in=kwargs.get("assets_display"))
).distinct()
instance.assets.add(*assets_to_set)
# 节点
@ -180,3 +180,26 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
instance = super().create(validated_data)
self.perform_display_create(instance, **display)
return instance
class AssetPermissionListSerializer(AssetPermissionSerializer):
users_amount = serializers.IntegerField(read_only=True, label=_("Users amount"))
user_groups_amount = serializers.IntegerField(read_only=True, label=_("User groups amount"))
assets_amount = serializers.IntegerField(read_only=True, label=_("Assets amount"))
nodes_amount = serializers.IntegerField(read_only=True, label=_("Nodes amount"))
class Meta(AssetPermissionSerializer.Meta):
amount_fields = ["users_amount", "user_groups_amount", "assets_amount", "nodes_amount"]
remove_fields = {"users", "assets", "nodes", "user_groups"}
fields = list(set(AssetPermissionSerializer.Meta.fields + amount_fields) - remove_fields)
@classmethod
def setup_eager_loading(cls, queryset):
"""Perform necessary eager loading of data."""
queryset = queryset.annotate(
users_amount=Count("users"),
user_groups_amount=Count("user_groups"),
assets_amount=Count("assets"),
nodes_amount=Count("nodes"),
)
return queryset

View File

@ -3,15 +3,13 @@
from django.db.models.signals import m2m_changed, pre_delete, pre_save, post_save
from django.dispatch import receiver
from users.models import User, UserGroup
from assets.models import Asset
from common.utils import get_logger, get_object_or_none
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR
from common.exceptions import M2MReverseNotAllowed
from common.utils import get_logger, get_object_or_none
from perms.models import AssetPermission
from perms.utils import UserPermTreeExpireUtil
from users.models import User, UserGroup
logger = get_logger(__file__)
@ -38,7 +36,7 @@ def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs):
group = UserGroup.objects.get(id=list(group_ids)[0])
org_id = group.org_id
has_group_perm = AssetPermission.user_groups.through.objects\
has_group_perm = AssetPermission.user_groups.through.objects \
.filter(usergroup_id__in=group_ids).exists()
if not has_group_perm:
return
@ -115,6 +113,7 @@ def on_asset_permission_user_groups_changed(sender, instance, action, pk_set, re
def on_node_asset_change(action, instance, reverse, pk_set, **kwargs):
if not need_rebuild_mapping_node(action):
return
print("Asset node changed: ", action)
if reverse:
asset_ids = pk_set
node_ids = [instance.id]

View File

@ -1,8 +1,7 @@
from django.db.models import QuerySet
from assets.models import Node, Asset
from common.utils import get_logger
from common.utils import get_logger, timeit
from perms.models import AssetPermission
logger = get_logger(__file__)
@ -13,6 +12,7 @@ __all__ = ['AssetPermissionUtil']
class AssetPermissionUtil(object):
""" 资产授权相关的方法工具 """
@timeit
def get_permissions_for_user(self, user, with_group=True, flat=False):
""" 获取用户的授权规则 """
perm_ids = set()

View File

@ -1,13 +1,22 @@
from django.conf import settings
from django.db.models import Q
import json
import re
from django.conf import settings
from django.core.cache import cache
from django.db.models import Q
from rest_framework.utils.encoders import JSONEncoder
from assets.const import AllTypes
from assets.models import FavoriteAsset, Asset
from common.utils.common import timeit
from common.utils.common import timeit, get_logger
from orgs.utils import current_org, tmp_to_root_org
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
from .permission import AssetPermissionUtil
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
logger = get_logger(__name__)
class AssetPermissionPermAssetUtil:
@ -16,29 +25,32 @@ class AssetPermissionPermAssetUtil:
def get_all_assets(self):
""" 获取所有授权的资产 """
node_asset_ids = self.get_perm_nodes_assets(flat=True)
direct_asset_ids = self.get_direct_assets(flat=True)
asset_ids = list(node_asset_ids) + list(direct_asset_ids)
assets = Asset.objects.filter(id__in=asset_ids)
return assets
node_assets = self.get_perm_nodes_assets()
direct_assets = self.get_direct_assets()
# 比原来的查到所有 asset id 再搜索块很多,因为当资产量大的时候,搜索会很慢
return (node_assets | direct_assets).distinct()
@timeit
def get_perm_nodes_assets(self, flat=False):
""" 获取所有授权节点下的资产 """
from assets.models import Node
nodes = Node.objects.prefetch_related('granted_by_permissions').filter(
granted_by_permissions__in=self.perm_ids).only('id', 'key')
nodes = Node.objects \
.prefetch_related('granted_by_permissions') \
.filter(granted_by_permissions__in=self.perm_ids) \
.only('id', 'key')
assets = PermNode.get_nodes_all_assets(*nodes)
if flat:
return assets.values_list('id', flat=True)
return set(assets.values_list('id', flat=True))
return assets
@timeit
def get_direct_assets(self, flat=False):
""" 获取直接授权的资产 """
assets = Asset.objects.order_by() \
.filter(granted_by_permissions__id__in=self.perm_ids) \
.distinct()
if flat:
return assets.values_list('id', flat=True)
return set(assets.values_list('id', flat=True))
return assets
@ -52,12 +64,62 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):
def get_ungroup_assets(self):
return self.get_direct_assets()
@timeit
def get_favorite_assets(self):
assets = self.get_all_assets()
assets = Asset.objects.all().valid()
asset_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True)
assets = assets.filter(id__in=list(asset_ids))
return assets
def get_type_nodes_tree(self):
assets = self.get_all_assets()
resource_platforms = assets.order_by('id').values_list('platform_id', flat=True)
node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=True)
pattern = re.compile(r'\(0\)?')
nodes = []
for node in node_all:
meta = node.get('meta', {})
if pattern.search(node['name']) or meta.get('type') == 'platform':
continue
_type = meta.get('_type')
if _type:
node['type'] = _type
node['category'] = meta.get('category')
meta.setdefault('data', {})
node['meta'] = meta
nodes.append(node)
return nodes
@classmethod
def get_type_nodes_tree_or_cached(cls, user):
key = f'perms:type-nodes-tree:{user.id}:{current_org.id}'
nodes = cache.get(key)
if nodes is None:
nodes = cls(user).get_type_nodes_tree()
nodes_json = json.dumps(nodes, cls=JSONEncoder)
cache.set(key, nodes_json, 60 * 60 * 24)
else:
nodes = json.loads(nodes)
return nodes
def refresh_type_nodes_tree_cache(self):
logger.debug("Refresh type nodes tree cache")
key = f'perms:type-nodes-tree:{self.user.id}:{current_org.id}'
cache.delete(key)
def refresh_favorite_assets(self):
favor_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True)
favor_ids = set(favor_ids)
with tmp_to_root_org():
valid_ids = self.get_all_assets() \
.filter(id__in=favor_ids) \
.values_list('id', flat=True)
valid_ids = set(valid_ids)
invalid_ids = favor_ids - valid_ids
FavoriteAsset.objects.filter(user=self.user, asset_id__in=invalid_ids).delete()
def get_node_assets(self, key):
node = PermNode.objects.get(key=key)
node.compute_node_from_and_assets_amount(self.user)
@ -134,7 +196,11 @@ class UserPermNodeUtil:
self.perm_ids = AssetPermissionUtil().get_permissions_for_user(self.user, flat=True)
def get_favorite_node(self):
assets_amount = UserPermAssetUtil(self.user).get_favorite_assets().count()
favor_ids = FavoriteAsset.objects \
.filter(user=self.user) \
.values_list('asset_id') \
.distinct()
assets_amount = Asset.objects.all().valid().filter(id__in=favor_ids).count()
return PermNode.get_favorite_node(assets_amount)
def get_ungrouped_node(self):

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from django.conf import settings
from django.core.cache import cache
from django.db import transaction
from assets.models import Asset
from assets.utils import NodeAssetsUtil
from common.db.models import output_as_string
from common.decorators import on_transaction_commit
from common.decorators import on_transaction_commit, merge_delay_run
from common.utils import get_logger
from common.utils.common import lazyproperty, timeit
from orgs.models import Organization
@ -23,6 +24,7 @@ from perms.models import (
PermNode
)
from users.models import User
from . import UserPermAssetUtil
from .permission import AssetPermissionUtil
logger = get_logger(__name__)
@ -50,24 +52,74 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
def __init__(self, user):
self.user = user
self.orgs = self.user.orgs.distinct()
self.org_ids = [str(o.id) for o in self.orgs]
@lazyproperty
def orgs(self):
return self.user.orgs.distinct()
@lazyproperty
def org_ids(self):
return [str(o.id) for o in self.orgs]
@lazyproperty
def cache_key_user(self):
return self.get_cache_key(self.user.id)
@lazyproperty
def cache_key_time(self):
key = 'perms.user.node_tree.built_time.{}'.format(self.user.id)
return key
@timeit
def refresh_if_need(self, force=False):
self._clean_user_perm_tree_for_legacy_org()
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
to_refresh_orgs = self.orgs if force else self._get_user_need_refresh_orgs()
if not to_refresh_orgs:
logger.info('Not have to refresh orgs')
return
with UserGrantedTreeRebuildLock(self.user.id):
logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets(users=(self.user,))
@timeit
def refresh_tree_manual(self):
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh just now, pass: {}'.format(built_just_now))
return
to_refresh_orgs = self._get_user_need_refresh_orgs()
if not to_refresh_orgs:
logger.info('Not have to refresh orgs for user: {}'.format(self.user))
return
self.perform_refresh_user_tree(to_refresh_orgs)
@timeit
def perform_refresh_user_tree(self, to_refresh_orgs):
# 再判断一次,毕竟构建树比较慢
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
self._clean_user_perm_tree_for_legacy_org()
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
lock = UserGrantedTreeRebuildLock(self.user.id)
got = lock.acquire(blocking=False)
if not got:
logger.info('User perm tree rebuild lock not acquired, pass')
return
try:
for org in to_refresh_orgs:
self._rebuild_user_perm_tree_for_org(org)
self._mark_user_orgs_refresh_finished(to_refresh_orgs)
self._mark_user_orgs_refresh_finished(to_refresh_orgs)
finally:
lock.release()
def _rebuild_user_perm_tree_for_org(self, org):
with tmp_to_org(org):
@ -75,7 +127,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
UserPermTreeBuildUtil(self.user).rebuild_user_perm_tree()
end = time.time()
logger.info(
'Refresh user [{user}] org [{org}] perm tree, user {use_time:.2f}s'
'Refresh user perm tree: [{user}] org [{org}] {use_time:.2f}s'
''.format(user=self.user, org=org, use_time=end - start)
)
@ -90,7 +142,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
cached_org_ids = self.client.smembers(self.cache_key_user)
cached_org_ids = {oid.decode() for oid in cached_org_ids}
to_refresh_org_ids = set(self.org_ids) - cached_org_ids
to_refresh_orgs = Organization.objects.filter(id__in=to_refresh_org_ids)
to_refresh_orgs = list(Organization.objects.filter(id__in=to_refresh_org_ids))
logger.info(f'Need to refresh orgs: {to_refresh_orgs}')
return to_refresh_orgs
@ -128,7 +180,8 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
self.expire_perm_tree_for_user_groups_orgs(group_ids, org_ids)
def expire_perm_tree_for_user_groups_orgs(self, group_ids, org_ids):
user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \
user_ids = User.groups.through.objects \
.filter(usergroup_id__in=group_ids) \
.values_list('user_id', flat=True).distinct()
self.expire_perm_tree_for_users_orgs(user_ids, org_ids)
@ -151,6 +204,21 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
logger.info('Expire all user perm tree')
@merge_delay_run(ttl=20)
def refresh_user_orgs_perm_tree(user_orgs=()):
for user, orgs in user_orgs:
util = UserPermTreeRefreshUtil(user)
util.perform_refresh_user_tree(orgs)
@merge_delay_run(ttl=20)
def refresh_user_favorite_assets(users=()):
for user in users:
util = UserPermAssetUtil(user)
util.refresh_favorite_assets()
util.refresh_type_nodes_tree_cache()
class UserPermTreeBuildUtil(object):
node_only_fields = ('id', 'key', 'parent_key', 'org_id')
@ -161,13 +229,14 @@ class UserPermTreeBuildUtil(object):
self._perm_nodes_key_node_mapper = {}
def rebuild_user_perm_tree(self):
self.clean_user_perm_tree()
if not self.user_perm_ids:
logger.info('User({}) not have permissions'.format(self.user))
return
self.compute_perm_nodes()
self.compute_perm_nodes_asset_amount()
self.create_mapping_nodes()
with transaction.atomic():
self.clean_user_perm_tree()
if not self.user_perm_ids:
logger.info('User({}) not have permissions'.format(self.user))
return
self.compute_perm_nodes()
self.compute_perm_nodes_asset_amount()
self.create_mapping_nodes()
def clean_user_perm_tree(self):
UserAssetGrantedTreeNodeRelation.objects.filter(user=self.user).delete()

View File

@ -139,7 +139,7 @@ class RBACPermission(permissions.DjangoModelPermissions):
if isinstance(perms, str):
perms = [perms]
has = request.user.has_perms(perms)
logger.debug('View require perms: {}, result: {}'.format(perms, has))
logger.debug('Api require perms: {}, result: {}'.format(perms, has))
return has
def has_object_permission(self, request, view, obj):

View File

@ -6,7 +6,7 @@ from rest_framework.response import Response
from orgs.mixins.api import OrgBulkModelViewSet
from ..models import UserGroup, User
from ..serializers import UserGroupSerializer
from ..serializers import UserGroupSerializer, UserGroupListSerializer
__all__ = ['UserGroupViewSet']
@ -15,7 +15,10 @@ class UserGroupViewSet(OrgBulkModelViewSet):
model = UserGroup
filterset_fields = ("name",)
search_fields = filterset_fields
serializer_class = UserGroupSerializer
serializer_classes = {
'default': UserGroupSerializer,
'list': UserGroupListSerializer,
}
ordering = ('name',)
rbac_perms = (
("add_all_users", "users.add_usergroup"),

View File

@ -2,6 +2,7 @@
#
from django.db.models import Count
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.serializers.fields import ObjectRelatedField
from common.serializers.mixin import ResourceLabelsMixin
@ -10,7 +11,7 @@ from .. import utils
from ..models import User, UserGroup
__all__ = [
'UserGroupSerializer',
'UserGroupSerializer', 'UserGroupListSerializer',
]
@ -29,7 +30,6 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
fields = fields_mini + fields_small + ['users', 'labels']
extra_kwargs = {
'created_by': {'label': _('Created by'), 'read_only': True},
'users_amount': {'label': _('Users amount')},
'id': {'label': _('ID')},
}
@ -45,6 +45,17 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('users', 'labels', 'labels__label') \
queryset = queryset.prefetch_related('labels', 'labels__label') \
.annotate(users_amount=Count('users'))
return queryset
class UserGroupListSerializer(UserGroupSerializer):
users_amount = serializers.IntegerField(label=_('Users amount'), read_only=True)
class Meta(UserGroupSerializer.Meta):
fields = list(set(UserGroupSerializer.Meta.fields + ['users_amount']) - {'users'})
extra_kwargs = {
**UserGroupSerializer.Meta.extra_kwargs,
'users_amount': {'label': _('Users amount')},
}

View File

@ -17,6 +17,7 @@ from resources.assets import AssetsGenerator, NodesGenerator, PlatformGenerator
from resources.users import UserGroupGenerator, UserGenerator
from resources.perms import AssetPermissionGenerator
from resources.terminal import CommandGenerator, SessionGenerator
from resources.accounts import AccountGenerator
resource_generator_mapper = {
'asset': AssetsGenerator,
@ -27,6 +28,7 @@ resource_generator_mapper = {
'asset_permission': AssetPermissionGenerator,
'command': CommandGenerator,
'session': SessionGenerator,
'account': AccountGenerator,
'all': None
# 'stat': StatGenerator
}
@ -45,6 +47,7 @@ def main():
parser.add_argument('-o', '--org', type=str, default='')
args = parser.parse_args()
resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org
resource = resource.lower().rstrip('s')
generator_cls = []
if resource == 'all':

View File

@ -0,0 +1,32 @@
import random
import forgery_py
from accounts.models import Account
from assets.models import Asset
from .base import FakeDataGenerator
class AccountGenerator(FakeDataGenerator):
resource = 'account'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.assets = list(list(Asset.objects.all()[:5000]))
def do_generate(self, batch, batch_size):
accounts = []
for i in batch:
asset = random.choice(self.assets)
name = forgery_py.internet.user_name(True) + '-' + str(i)
d = {
'username': name,
'name': name,
'asset': asset,
'secret': name,
'secret_type': 'password',
'is_active': True,
'privileged': False,
}
accounts.append(Account(**d))
Account.objects.bulk_create(accounts, ignore_conflicts=True)

View File

@ -48,7 +48,7 @@ class AssetsGenerator(FakeDataGenerator):
def pre_generate(self):
self.node_ids = list(Node.objects.all().values_list('id', flat=True))
self.platform_ids = list(Platform.objects.all().values_list('id', flat=True))
self.platform_ids = list(Platform.objects.filter(category='host').values_list('id', flat=True))
def set_assets_nodes(self, assets):
for asset in assets:
@ -72,6 +72,17 @@ class AssetsGenerator(FakeDataGenerator):
assets.append(Asset(**data))
creates = Asset.objects.bulk_create(assets, ignore_conflicts=True)
self.set_assets_nodes(creates)
self.set_asset_platform(creates)
@staticmethod
def set_asset_platform(assets):
protocol = random.choice(['ssh', 'rdp', 'telnet', 'vnc'])
protocols = []
for asset in assets:
port = 22 if protocol == 'ssh' else 3389
protocols.append(Protocol(asset=asset, name=protocol, port=port))
Protocol.objects.bulk_create(protocols, ignore_conflicts=True)
def after_generate(self):
pass

View File

@ -41,7 +41,7 @@ class FakeDataGenerator:
start = time.time()
self.do_generate(batch, self.batch_size)
end = time.time()
using = end - start
using = round(end - start, 3)
from_size = created
created += len(batch)
print('Generate %s: %s-%s [%s]' % (self.resource, from_size, created, using))

View File

@ -1,9 +1,11 @@
from random import choice, sample
from random import sample
import forgery_py
from .base import FakeDataGenerator
from orgs.utils import current_org
from rbac.models import RoleBinding, Role
from users.models import *
from .base import FakeDataGenerator
class UserGroupGenerator(FakeDataGenerator):
@ -47,3 +49,12 @@ class UserGenerator(FakeDataGenerator):
users.append(u)
users = User.objects.bulk_create(users, ignore_conflicts=True)
self.set_groups(users)
self.set_to_org(users)
def set_to_org(self, users):
bindings = []
role = Role.objects.get(name='OrgUser')
for u in users:
b = RoleBinding(user=u, role=role, org_id=current_org.id, scope='org')
bindings.append(b)
RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True)