perf: Optimize tree nodes to avoid slow responses (#12472)

* perf: Optimize tree nodes to avoid slow responses

perf: Speed up generating asset amounts when there are many assets

perf: Optimize the node tree

perf: Tweak tree nodes

perf: Optimize some APIs with heavy queries

perf: Optimize the platform API

perf: Return the synced tree in pages

perf: Optimize the node tree

perf: Deeply optimize the node tree

* perf: remove unused config

---------

Co-authored-by: ibuler <ibuler@qq.com>
pull/12481/head
fit2bot 2024-01-02 16:11:56 +08:00 committed by GitHub
parent e80a0e41ba
commit 2fcbfe9f21
38 changed files with 508 additions and 236 deletions

View File

@@ -21,7 +21,6 @@ from common.drf.filters import BaseFilterSet, AttrRulesFilterBackend
 from common.utils import get_logger, is_uuid
 from orgs.mixins import generics
 from orgs.mixins.api import OrgBulkModelViewSet
-from ..mixin import NodeFilterMixin
 from ...notifications import BulkUpdatePlatformSkipAssetUserMsg

 logger = get_logger(__file__)

@@ -86,7 +85,7 @@ class AssetFilterSet(BaseFilterSet):
         return queryset.filter(protocols__name__in=value).distinct()

-class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
+class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
     """
     API endpoint that allows Asset to be viewed or edited.
     """

@@ -114,9 +113,7 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
     ]

     def get_queryset(self):
-        queryset = super().get_queryset() \
-            .prefetch_related('nodes', 'protocols') \
-            .select_related('platform', 'domain')
+        queryset = super().get_queryset()
         if queryset.model is not Asset:
             queryset = queryset.select_related('asset_ptr')
         return queryset

View File

@@ -20,14 +20,15 @@ class DomainViewSet(OrgBulkModelViewSet):
     filterset_fields = ("name",)
     search_fields = filterset_fields
     ordering = ('name',)
+    serializer_classes = {
+        'default': serializers.DomainSerializer,
+        'list': serializers.DomainListSerializer,
+    }

     def get_serializer_class(self):
         if self.request.query_params.get('gateway'):
             return serializers.DomainWithGatewaySerializer
-        return serializers.DomainSerializer
+        return super().get_serializer_class()

-    def get_queryset(self):
-        return super().get_queryset().prefetch_related('assets')

 class GatewayViewSet(HostViewSet):
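A note on the pattern above: the view now declares a serializer_classes mapping instead of hard-coding one serializer in get_serializer_class(). How the base viewset resolves that mapping is not shown in this diff; the sketch below is only an assumed, minimal version of such a lookup, and the mixin name is hypothetical.

# Assumed, minimal resolution of a `serializer_classes` mapping; not the project's actual base class.
class SerializerClassesMixin:
    serializer_classes = {}

    def get_serializer_class(self):
        if self.serializer_classes:
            # Fall back to the 'default' entry when there is no action-specific serializer.
            return self.serializer_classes.get(self.action, self.serializer_classes['default'])
        return super().get_serializer_class()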

View File

@@ -2,7 +2,7 @@ from typing import List
 from rest_framework.request import Request

-from assets.models import Node, Protocol
+from assets.models import Node, Platform, Protocol
 from assets.utils import get_node_from_request, is_query_node_all_assets
 from common.utils import lazyproperty, timeit

@@ -71,37 +71,43 @@ class SerializeToTreeNodeMixin:
         return 'file'

     @timeit
-    def serialize_assets(self, assets, node_key=None, pid=None):
-        if node_key is None:
-            get_pid = lambda asset: getattr(asset, 'parent_key', '')
-        else:
-            get_pid = lambda asset: node_key
+    def serialize_assets(self, assets, node_key=None, get_pid=None):
+        if not get_pid and not node_key:
+            get_pid = lambda asset, platform: getattr(asset, 'parent_key', '')

         sftp_asset_ids = Protocol.objects.filter(name='sftp') \
             .values_list('asset_id', flat=True)
-        sftp_asset_ids = list(sftp_asset_ids)
-        data = [
-            {
+        sftp_asset_ids = set(sftp_asset_ids)
+        platform_map = {p.id: p for p in Platform.objects.all()}
+
+        data = []
+        for asset in assets:
+            platform = platform_map.get(asset.platform_id)
+            if not platform:
+                continue
+            pid = node_key or get_pid(asset, platform)
+            if not pid or pid.isdigit():
+                continue
+            data.append({
                 'id': str(asset.id),
                 'name': asset.name,
-                'title': f'{asset.address}\n{asset.comment}',
-                'pId': pid or get_pid(asset),
+                'title': f'{asset.address}\n{asset.comment}'.strip(),
+                'pId': pid,
                 'isParent': False,
                 'open': False,
-                'iconSkin': self.get_icon(asset),
+                'iconSkin': self.get_icon(platform),
                 'chkDisabled': not asset.is_active,
                 'meta': {
                     'type': 'asset',
                     'data': {
-                        'platform_type': asset.platform.type,
+                        'platform_type': platform.type,
                         'org_name': asset.org_name,
                         'sftp': asset.id in sftp_asset_ids,
                         'name': asset.name,
                         'address': asset.address
                     },
                 }
-            }
-            for asset in assets
-        ]
+            })
         return data
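The reworked serialize_assets resolves each asset's platform from a prefetched map and derives the tree parent id (pId) either from a fixed node_key or from a get_pid(asset, platform) callback. A toy, self-contained illustration of the resulting node shape follows; the objects below are stand-ins, not the real models.

from types import SimpleNamespace

# Stand-ins for Asset/Platform rows; field names mirror what serialize_assets reads.
platform = SimpleNamespace(id=1, category='host', type='linux')
asset = SimpleNamespace(id='a1', name='web-1', address='10.0.0.8', comment='',
                        is_active=True, platform_id=1, org_name='Default')

get_pid = lambda a, p: 'ROOT_{}_{}'.format(p.category.upper(), p.type)

node = {
    'id': str(asset.id),
    'name': asset.name,
    'title': f'{asset.address}\n{asset.comment}'.strip(),
    'pId': get_pid(asset, platform),          # -> 'ROOT_HOST_linux'
    'isParent': False,
    'chkDisabled': not asset.is_active,
    'meta': {'type': 'asset', 'data': {'platform_type': platform.type}},
}
print(node['pId'])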

View File

@@ -29,7 +29,9 @@ class AssetPlatformViewSet(JMSModelViewSet):
     }

     def get_queryset(self):
-        queryset = super().get_queryset()
+        queryset = super().get_queryset().prefetch_related(
+            'protocols', 'automation'
+        )
         queryset = queryset.filter(type__in=AllTypes.get_types_values())
         return queryset

View File

@@ -126,6 +126,8 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
         include_assets = self.request.query_params.get('assets', '0') == '1'
         if not self.instance or not include_assets:
             return Asset.objects.none()
+        if self.instance.is_org_root():
+            return Asset.objects.none()
         if query_all:
             assets = self.instance.get_all_assets()
         else:

View File

@@ -268,7 +268,7 @@ class AllTypes(ChoicesMixin):
             meta = {'type': 'category', 'category': category.value, '_type': category.value}
             category_node = cls.choice_to_node(category, 'ROOT', meta=meta)
             category_count = category_type_mapper.get(category, 0)
-            category_node['name'] += f'({category_count})'
+            category_node['name'] += f' ({category_count})'
             nodes.append(category_node)

             # Type 格式化

@@ -277,7 +277,7 @@ class AllTypes(ChoicesMixin):
                 meta = {'type': 'type', 'category': category.value, '_type': tp.value}
                 tp_node = cls.choice_to_node(tp, category_node['id'], opened=False, meta=meta)
                 tp_count = category_type_mapper.get(category + '_' + tp, 0)
-                tp_node['name'] += f'({tp_count})'
+                tp_node['name'] += f' ({tp_count})'
                 platforms = tp_platforms.get(category + '_' + tp, [])
                 if not platforms:
                     tp_node['isParent'] = False

@@ -286,7 +286,7 @@ class AllTypes(ChoicesMixin):
                 # Platform 格式化
                 for p in platforms:
                     platform_node = cls.platform_to_node(p, tp_node['id'], include_asset)
-                    platform_node['name'] += f'({platform_count.get(p.id, 0)})'
+                    platform_node['name'] += f' ({platform_count.get(p.id, 0)})'
                     nodes.append(platform_node)
         return nodes

View File

@@ -63,11 +63,10 @@ class NodeFilterBackend(filters.BaseFilterBackend):
         query_all = is_query_node_all_assets(request)
         if query_all:
             return queryset.filter(
-                Q(nodes__key__istartswith=f'{node.key}:') |
+                Q(nodes__key__startswith=f'{node.key}:') |
                 Q(nodes__key=node.key)
             ).distinct()
         else:
-            print("Query query origin: ", queryset.count())
             return queryset.filter(nodes__key=node.key).distinct()

View File

@@ -13,7 +13,7 @@ from django.db.transaction import atomic
 from django.utils.translation import gettext_lazy as _, gettext

 from common.db.models import output_as_string
-from common.utils import get_logger
+from common.utils import get_logger, timeit
 from common.utils.lock import DistributedLock
 from orgs.mixins.models import OrgManager, JMSOrgBaseModel
 from orgs.models import Organization

@@ -195,11 +195,6 @@ class FamilyMixin:
         ancestor_keys = self.get_ancestor_keys(with_self=with_self)
         return self.__class__.objects.filter(key__in=ancestor_keys)

-    # @property
-    # def parent_key(self):
-    #     parent_key = ":".join(self.key.split(":")[:-1])
-    #     return parent_key

     def compute_parent_key(self):
         return compute_parent_key(self.key)

@@ -349,29 +344,26 @@ class NodeAllAssetsMappingMixin:
         return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id)

     @classmethod
+    @timeit
     def generate_node_all_asset_ids_mapping(cls, org_id):
-        from .asset import Asset
-        logger.info(f'Generate node asset mapping: '
-                    f'thread={threading.get_ident()} '
-                    f'org_id={org_id}')
+        logger.info(f'Generate node asset mapping: org_id={org_id}')
         t1 = time.time()
         with tmp_to_org(org_id):
             node_ids_key = Node.objects.annotate(
                 char_id=output_as_string('id')
             ).values_list('char_id', 'key')

-            # * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
-            nodes_asset_ids = Asset.nodes.through.objects.all() \
-                .annotate(char_node_id=output_as_string('node_id')) \
-                .annotate(char_asset_id=output_as_string('asset_id')) \
-                .values_list('char_node_id', 'char_asset_id')
             node_id_ancestor_keys_mapping = {
                 node_id: cls.get_node_ancestor_keys(node_key, with_self=True)
                 for node_id, node_key in node_ids_key
             }

+            # * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
+            nodes_asset_ids = cls.assets.through.objects.all() \
+                .annotate(char_node_id=output_as_string('node_id')) \
+                .annotate(char_asset_id=output_as_string('asset_id')) \
+                .values_list('char_node_id', 'char_asset_id')

             nodeid_assetsid_mapping = defaultdict(set)
             for node_id, asset_id in nodes_asset_ids:
                 nodeid_assetsid_mapping[node_id].add(asset_id)

@@ -386,7 +378,7 @@
                 mapping[ancestor_key].update(asset_ids)

         t3 = time.time()
-        logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2 - t1, t3 - t2))
+        logger.info('Generate asset nodes mapping, DB query: {:.2f}s, mapping: {:.2f}s'.format(t2 - t1, t3 - t2))
         return mapping

@@ -436,6 +428,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
         return asset_ids

     @classmethod
+    @timeit
     def get_nodes_all_assets(cls, *nodes):
         from .asset import Asset
         node_ids = set()

@@ -559,11 +552,6 @@ class Node(JMSOrgBaseModel, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
     def __str__(self):
         return self.full_value

-    # def __eq__(self, other):
-    #     if not other:
-    #         return False
-    #     return self.id == other.id
-    #
     def __gt__(self, other):
         self_key = [int(k) for k in self.key.split(':')]
         other_key = [int(k) for k in other.key.split(':')]
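The mapping generation above folds every (node_id, asset_id) row into all ancestor keys of that node, so each node key ends up holding its full descendant asset set. A tiny, self-contained model of that aggregation, using toy data rather than the real query:

from collections import defaultdict

# Toy version of the mapping build: each (node_id, asset_id) pair is folded into every
# ancestor key of that node (including the node itself).
ancestors = {'n1': ['1'], 'n2': ['1', '1:1']}            # node_id -> ancestor keys, with self
pairs = [('n1', 'a1'), ('n2', 'a2')]                     # flattened node-asset through rows

node_assets = defaultdict(set)
for node_id, asset_id in pairs:
    node_assets[node_id].add(asset_id)

mapping = defaultdict(set)
for node_id, asset_ids in node_assets.items():
    for key in ancestors[node_id]:
        mapping[key].update(asset_ids)

assert mapping['1'] == {'a1', 'a2'} and mapping['1:1'] == {'a2'}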

View File

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 #
+from django.db.models import Count
 from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers

@@ -7,18 +8,15 @@ from common.serializers import ResourceLabelsMixin
 from common.serializers.fields import ObjectRelatedField
 from orgs.mixins.serializers import BulkOrgResourceModelSerializer
 from .gateway import GatewayWithAccountSecretSerializer
-from ..models import Domain, Asset
+from ..models import Domain

-__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer']
+__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer', 'DomainListSerializer']


 class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
     gateways = ObjectRelatedField(
         many=True, required=False, label=_('Gateway'), read_only=True,
     )
-    assets = ObjectRelatedField(
-        many=True, required=False, queryset=Asset.objects, label=_('Asset')
-    )

     class Meta:
         model = Domain

@@ -30,7 +28,9 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
     def to_representation(self, instance):
         data = super().to_representation(instance)
-        assets = data['assets']
+        assets = data.get('assets')
+        if assets is None:
+            return data
         gateway_ids = [str(i['id']) for i in data['gateways']]
         data['assets'] = [i for i in assets if str(i['id']) not in gateway_ids]
         return data

@@ -49,6 +49,20 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
         return queryset


+class DomainListSerializer(DomainSerializer):
+    assets_amount = serializers.IntegerField(label=_('Assets amount'), read_only=True)
+
+    class Meta(DomainSerializer.Meta):
+        fields = list(set(DomainSerializer.Meta.fields + ['assets_amount']) - {'assets'})
+
+    @classmethod
+    def setup_eager_loading(cls, queryset):
+        queryset = queryset.annotate(
+            assets_amount=Count('assets'),
+        )
+        return queryset
+
+
 class DomainWithGatewaySerializer(serializers.ModelSerializer):
     gateways = GatewayWithAccountSecretSerializer(many=True, read_only=True)

View File

@@ -80,10 +80,11 @@ RELATED_NODE_IDS = '_related_node_ids'

 @receiver(pre_delete, sender=Asset)
 def on_asset_delete(instance: Asset, using, **kwargs):
-    logger.debug("Asset pre delete signal recv: {}".format(instance))
     node_ids = Node.objects.filter(assets=instance) \
         .distinct().values_list('id', flat=True)
-    setattr(instance, RELATED_NODE_IDS, node_ids)
+    node_ids = list(node_ids)
+    logger.debug("Asset pre delete signal recv: {}, node_ids: {}".format(instance, node_ids))
+    setattr(instance, RELATED_NODE_IDS, list(node_ids))
     m2m_changed.send(
         sender=Asset.nodes.through, instance=instance,
         reverse=False, model=Node, pk_set=node_ids,

@@ -93,8 +94,8 @@ def on_asset_delete(instance: Asset, using, **kwargs):

 @receiver(post_delete, sender=Asset)
 def on_asset_post_delete(instance: Asset, using, **kwargs):
-    logger.debug("Asset post delete signal recv: {}".format(instance))
     node_ids = getattr(instance, RELATED_NODE_IDS, [])
+    logger.debug("Asset post delete signal recv: {}, node_ids: {}".format(instance, node_ids))
     if node_ids:
         m2m_changed.send(
             sender=Asset.nodes.through, instance=instance, reverse=False,

View File

@@ -15,8 +15,8 @@ from ..tasks import check_node_assets_amount_task
 logger = get_logger(__file__)


-@on_transaction_commit
 @receiver(m2m_changed, sender=Asset.nodes.through)
+@on_transaction_commit
 def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
     # 不允许 `pre_clear` ,因为该信号没有 `pk_set`
     # [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)

@@ -37,7 +37,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
     update_nodes_assets_amount(node_ids=node_ids)


-@merge_delay_run(ttl=5)
+@merge_delay_run(ttl=30)
 def update_nodes_assets_amount(node_ids=()):
     nodes = Node.objects.filter(id__in=node_ids)
     nodes = Node.get_ancestor_queryset(nodes)
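Raising ttl from 5 to 30 widens the merge window of merge_delay_run: calls arriving within the window are coalesced into one delayed execution. The decorator's real implementation is not shown in this commit; the sketch below is only a rough, single-process approximation of the merge-and-delay idea.

import threading

def merge_delay_run_sketch(ttl=30):
    # Rough approximation: merge arguments of calls made within `ttl` seconds
    # and run the wrapped function once with the union of them.
    def decorator(func):
        pending, lock = set(), threading.Lock()
        timer = None

        def flush():
            nonlocal timer
            with lock:
                items = tuple(pending)
                pending.clear()
                timer = None
            func(items)

        def wrapper(items=()):
            nonlocal timer
            with lock:
                pending.update(items)
                if timer is None:
                    timer = threading.Timer(ttl, flush)
                    timer.daemon = True
                    timer.start()
        return wrapper
    return decorator

@merge_delay_run_sketch(ttl=30)
def recompute_assets_amount(node_ids):
    print('recompute assets amount for nodes:', sorted(node_ids))

# Several rapid calls collapse into a single run about 30 seconds later.
recompute_assets_amount({'1:1'})
recompute_assets_amount({'1:2'})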

View File

@@ -21,7 +21,7 @@ logger = get_logger(__name__)
 node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)()


-@merge_delay_run(ttl=5)
+@merge_delay_run(ttl=30)
 def expire_node_assets_mapping(org_ids=()):
     logger.debug("Recv asset nodes changed signal, expire memery node asset mapping")
     # 所有进程清除(自己的 memory 数据)

@@ -53,8 +53,9 @@ def on_node_post_delete(sender, instance, **kwargs):

 @receiver(m2m_changed, sender=Asset.nodes.through)
-def on_node_asset_change(sender, instance, **kwargs):
-    expire_node_assets_mapping(org_ids=(instance.org_id,))
+def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
+    if action.startswith('post'):
+        expire_node_assets_mapping(org_ids=(instance.org_id,))


 @receiver(django_ready)

View File

@@ -98,12 +98,17 @@ class QuerySetMixin:
             return queryset
         if self.action == 'metadata':
             queryset = queryset.none()
-        if self.action in ['list', 'metadata']:
-            serializer_class = self.get_serializer_class()
-            if serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
-                queryset = serializer_class.setup_eager_loading(queryset)
         return queryset

+    def paginate_queryset(self, queryset):
+        page = super().paginate_queryset(queryset)
+        serializer_class = self.get_serializer_class()
+        if page and serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
+            ids = [i.id for i in page]
+            page = self.get_queryset().filter(id__in=ids)
+            page = serializer_class.setup_eager_loading(page)
+        return page


 class ExtraFilterFieldsMixin:
     """

View File

@@ -65,7 +65,7 @@ class EventLoopThread(threading.Thread):
 _loop_thread = EventLoopThread()
-_loop_thread.setDaemon(True)
+_loop_thread.daemon = True
 _loop_thread.start()

 executor = ThreadPoolExecutor(
     max_workers=10,

View File

@@ -62,7 +62,7 @@ def digest_sql_query():
         method = current_request.method
         path = current_request.get_full_path()
-        print(">>> [{}] {}".format(method, path))
+        print(">>>. [{}] {}".format(method, path))
         for table_name, queries in table_queries.items():
             if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
                 continue

@@ -77,9 +77,9 @@
             sql = query['sql']
             if not sql or not sql.startswith('SELECT'):
                 continue
-            print('\t{}. {}'.format(i, sql))
+            print('\t{}.[{}s] {}'.format(i, round(float(query['time']), 2), sql[:1000]))

-    logger.debug(">>> [{}] {}".format(method, path))
+    # logger.debug(">>> [{}] {}".format(method, path))
     for name, counter in counters:
         logger.debug("Query {:3} times using {:.2f}s {}".format(
             counter.counter, counter.time, name)

View File

@@ -220,7 +220,7 @@ def timeit(func):
         now = time.time()
         result = func(*args, **kwargs)
         using = (time.time() - now) * 1000
-        msg = "End call {}, using: {:.1f}ms".format(name, using)
+        msg = "Ends call: {}, using: {:.1f}ms".format(name, using)
         logger.debug(msg)
         return result

View File

@@ -1,18 +1,16 @@
-from functools import wraps
 import threading
+from functools import wraps

+from django.db import transaction
 from redis_lock import (
     Lock as RedisLock, NotAcquired, UNLOCK_SCRIPT,
     EXTEND_SCRIPT, RESET_SCRIPT, RESET_ALL_SCRIPT
 )
-from redis import Redis
-from django.db import transaction
-from common.utils import get_logger
-from common.utils.inspect import copy_function_args
-from common.utils.connection import get_redis_client
-from jumpserver.const import CONFIG
 from common.local import thread_local
+from common.utils import get_logger
+from common.utils.connection import get_redis_client
+from common.utils.inspect import copy_function_args

 logger = get_logger(__file__)

@@ -76,6 +74,7 @@ class DistributedLock(RedisLock):
                 # 要创建一个新的锁对象
                 with self.__class__(**self.kwargs_copy):
                     return func(*args, **kwds)
+
         return inner

     @classmethod

@@ -95,7 +94,6 @@ class DistributedLock(RedisLock):
         if self.locked():
             owner_id = self.get_owner_id()
             local_owner_id = getattr(thread_local, self.name, None)
-
             if local_owner_id and owner_id == local_owner_id:
                 return True
         return False

@@ -140,14 +138,16 @@
             logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
             return
         else:
-            self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
+            self._raise_exc_with_log(
+                f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')

     def _release_on_reentrant_locked_by_me(self):
         logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name}')

         id = getattr(thread_local, self.name, None)
         if id != self.id:
-            raise PermissionError(f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
+            raise PermissionError(
+                f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
         try:
             # 这里要保证先删除 thread_local 的标记,
             delattr(thread_local, self.name)

@@ -191,7 +191,7 @@
         # 处理是否在事务提交时才释放锁
         if self._release_on_transaction_commit:
             logger.debug(
-                f'Release lock on transaction commit ... :lock_id={self.id} lock={self.name}')
+                f'Release lock on transaction commit:lock_id={self.id} lock={self.name}')
             transaction.on_commit(_release)
         else:
             _release()

View File

@@ -531,6 +531,7 @@ class Config(dict):
         'SYSLOG_SOCKTYPE': 2,

         'PERM_EXPIRED_CHECK_PERIODIC': 60 * 60,
+        'PERM_TREE_REGEN_INTERVAL': 1,
         'FLOWER_URL': "127.0.0.1:5555",
         'LANGUAGE_CODE': 'zh',
         'TIME_ZONE': 'Asia/Shanghai',
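PERM_TREE_REGEN_INTERVAL (default 1 second) is used later in this commit as the TTL of a "tree built just now" cache marker, so repeated rebuild requests inside the interval are skipped. A condensed sketch of that throttle, reusing the key layout from the perm-tree refresh util further down:

import time
from django.conf import settings
from django.core.cache import cache

def maybe_rebuild(user_id, rebuild):
    # Skip the rebuild if one already ran within PERM_TREE_REGEN_INTERVAL seconds.
    key = 'perms.user.node_tree.built_time.{}'.format(user_id)
    if cache.get(key):
        return False
    cache.set(key, int(time.time()), settings.PERM_TREE_REGEN_INTERVAL)
    rebuild()
    return True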

View File

@@ -208,6 +208,7 @@ OPERATE_LOG_ELASTICSEARCH_CONFIG = CONFIG.OPERATE_LOG_ELASTICSEARCH_CONFIG
 MAX_LIMIT_PER_PAGE = CONFIG.MAX_LIMIT_PER_PAGE
 DEFAULT_PAGE_SIZE = CONFIG.DEFAULT_PAGE_SIZE
+PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL

 # Magnus DB Port
 MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS

View File

@@ -21,7 +21,7 @@ LOGGING = {
     },
     'main': {
         'datefmt': '%Y-%m-%d %H:%M:%S',
-        'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s',
+        'format': '%(asctime)s [%(levelname).4s] %(message)s',
     },
     'exception': {
         'datefmt': '%Y-%m-%d %H:%M:%S',

View File

@@ -75,7 +75,7 @@ model_cache_field_mapper = {
 class OrgResourceStatisticsRefreshUtil:
     @staticmethod
-    @merge_delay_run(ttl=5)
+    @merge_delay_run(ttl=30)
     def refresh_org_fields(org_fields=()):
         for org, cache_field_name in org_fields:
             OrgResourceStatisticsCache(org).expire(*cache_field_name)

@@ -104,7 +104,7 @@ def on_post_delete_refresh_org_resource_statistics_cache(sender, instance, **kwa
 def _refresh_session_org_resource_statistics_cache(instance: Session):
     cache_field_name = [
         'total_count_online_users', 'total_count_online_sessions',
-        'total_count_today_active_assets','total_count_today_failed_sessions'
+        'total_count_today_active_assets', 'total_count_today_failed_sessions'
     ]
     org_cache = OrgResourceStatisticsCache(instance.org)

View File

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 #
+
 from orgs.mixins.api import OrgBulkModelViewSet
 from perms import serializers
 from perms.filters import AssetPermissionFilter

@@ -13,7 +14,10 @@ class AssetPermissionViewSet(OrgBulkModelViewSet):
     资产授权列表的增删改查api
     """
     model = AssetPermission
-    serializer_class = serializers.AssetPermissionSerializer
+    serializer_classes = {
+        'default': serializers.AssetPermissionSerializer,
+        'list': serializers.AssetPermissionListSerializer,
+    }
     filterset_class = AssetPermissionFilter
     search_fields = ('name',)
     ordering = ('name',)

View File

@@ -1,16 +1,14 @@
 from django.conf import settings
 from rest_framework.response import Response

-from assets.models import Asset
 from assets.api import SerializeToTreeNodeMixin
+from assets.models import Asset
 from common.utils import get_logger
-from ..assets import UserAllPermedAssetsApi
 from .mixin import RebuildTreeMixin
+from ..assets import UserAllPermedAssetsApi

 logger = get_logger(__name__)

 __all__ = [
     'UserAllPermedAssetsAsTreeApi',
     'UserUngroupAssetsAsTreeApi',

@@ -31,7 +29,7 @@ class AssetTreeMixin(RebuildTreeMixin, SerializeToTreeNodeMixin):
         if request.query_params.get('search'):
             """ 限制返回数量, 搜索的条件不精准时,会返回大量的无意义数据 """
             assets = assets[:999]
-        data = self.serialize_assets(assets, None)
+        data = self.serialize_assets(assets, 'root')
         return Response(data=data)

@@ -42,6 +40,7 @@ class UserAllPermedAssetsAsTreeApi(AssetTreeMixin, UserAllPermedAssetsApi):
 class UserUngroupAssetsAsTreeApi(UserAllPermedAssetsAsTreeApi):
     """ 用户 '未分组节点的资产(直接授权的资产)' 作为树 """
+
     def get_assets(self):
         if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
             return super().get_assets()

View File

@@ -1,6 +1,4 @@
 import abc
-import re
-from collections import defaultdict
 from urllib.parse import parse_qsl

 from django.conf import settings

@@ -13,7 +11,6 @@
 from rest_framework.response import Response
 from accounts.const import AliasAccount
 from assets.api import SerializeToTreeNodeMixin
-from assets.const import AllTypes
 from assets.models import Asset
 from assets.utils import KubernetesTree
 from authentication.models import ConnectionToken

@@ -38,21 +35,36 @@ class BaseUserNodeWithAssetAsTreeApi(
     SelfOrPKUserMixin, RebuildTreeMixin,
     SerializeToTreeNodeMixin, ListAPIView
 ):
+    page_limit = 10000

     def list(self, request, *args, **kwargs):
-        nodes, assets = self.get_nodes_assets()
-        tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
-        tree_assets = self.serialize_assets(assets, node_key=self.node_key_for_serialize_assets)
-        data = list(tree_nodes) + list(tree_assets)
-        return Response(data=data)
+        offset = int(request.query_params.get('offset', 0))
+        page_assets = self.get_page_assets()
+
+        if not offset:
+            nodes, assets = self.get_nodes_assets()
+            page = page_assets[:self.page_limit]
+            assets = [*assets, *page]
+            tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
+            tree_assets = self.serialize_assets(assets, **self.serialize_asset_kwargs)
+            data = list(tree_nodes) + list(tree_assets)
+        else:
+            page = page_assets[offset:(offset + self.page_limit)]
+            data = self.serialize_assets(page, **self.serialize_asset_kwargs) if page else []
+
+        offset += len(page)
+        headers = {'X-JMS-TREE-OFFSET': offset} if offset else {}
+        return Response(data=data, headers=headers)

     @abc.abstractmethod
     def get_nodes_assets(self):
         return [], []

-    @lazyproperty
-    def node_key_for_serialize_assets(self):
-        return None
+    def get_page_assets(self):
+        return []
+
+    @property
+    def serialize_asset_kwargs(self):
+        return {}


 class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):

@@ -61,7 +73,6 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
     def get_nodes_assets(self):
         self.query_node_util = UserPermNodeUtil(self.request.user)
-        self.query_asset_util = UserPermAssetUtil(self.request.user)
         ung_nodes, ung_assets = self._get_nodes_assets_for_ungrouped()
         fav_nodes, fav_assets = self._get_nodes_assets_for_favorite()
         all_nodes, all_assets = self._get_nodes_assets_for_all()

@@ -69,31 +80,37 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
         assets = list(ung_assets) + list(fav_assets) + list(all_assets)
         return nodes, assets

+    def get_page_assets(self):
+        return self.query_asset_util.get_all_assets().annotate(parent_key=F('nodes__key'))
+
     @timeit
     def _get_nodes_assets_for_ungrouped(self):
         if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
             return [], []
         node = self.query_node_util.get_ungrouped_node()
         assets = self.query_asset_util.get_ungroup_assets()
-        assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \
-            .prefetch_related('platform')
+        assets = assets.annotate(parent_key=Value(node.key, output_field=CharField()))
         return [node], assets

+    @lazyproperty
+    def query_asset_util(self):
+        return UserPermAssetUtil(self.user)
+
     @timeit
     def _get_nodes_assets_for_favorite(self):
         node = self.query_node_util.get_favorite_node()
         assets = self.query_asset_util.get_favorite_assets()
-        assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \
-            .prefetch_related('platform')
+        assets = assets.annotate(parent_key=Value(node.key, output_field=CharField()))
         return [node], assets

+    @timeit
     def _get_nodes_assets_for_all(self):
         nodes = self.query_node_util.get_whole_tree_nodes(with_special=False)
         if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
             assets = self.query_asset_util.get_perm_nodes_assets()
         else:
-            assets = self.query_asset_util.get_all_assets()
-        assets = assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform')
+            assets = Asset.objects.none()
+        assets = assets.annotate(parent_key=F('nodes__key'))
         return nodes, assets

@@ -103,6 +120,7 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
     # 默认展开的节点key
     default_unfolded_node_key = None

+    @timeit
     def get_nodes_assets(self):
         query_node_util = UserPermNodeUtil(self.user)
         query_asset_util = UserPermAssetUtil(self.user)

@@ -136,14 +154,14 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
         node_key = getattr(node, 'key', None)
         return node_key

-    @lazyproperty
-    def node_key_for_serialize_assets(self):
-        return self.query_node_key or self.default_unfolded_node_key
+    @property
+    def serialize_asset_kwargs(self):
+        return {
+            'node_key': self.query_node_key or self.default_unfolded_node_key
+        }


-class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(
-    SelfOrPKUserMixin, SerializeToTreeNodeMixin, ListAPIView
-):
+class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(BaseUserNodeWithAssetAsTreeApi):
     @property
     def is_sync(self):
         sync = self.request.query_params.get('sync', 0)

@@ -151,66 +169,52 @@
     @property
     def tp(self):
-        return self.request.query_params.get('type')
-
-    def get_assets(self):
-        query_asset_util = UserPermAssetUtil(self.user)
-        node = PermNode.objects.filter(
-            granted_node_rels__user=self.user, parent_key='').first()
-        if node:
-            __, assets = query_asset_util.get_node_all_assets(node.id)
-        else:
-            assets = Asset.objects.none()
-        return assets
-
-    def to_tree_nodes(self, assets):
-        if not assets:
-            return []
-        assets = assets.annotate(tp=F('platform__type'))
-        asset_type_map = defaultdict(list)
-        for asset in assets:
-            asset_type_map[asset.tp].append(asset)
-        tp = self.tp
-        if tp:
-            assets = asset_type_map.get(tp, [])
-            if not assets:
-                return []
-            pid = f'ROOT_{str(assets[0].category).upper()}_{tp}'
-            return self.serialize_assets(assets, pid=pid)
         params = self.request.query_params
-        get_root = not list(filter(lambda x: params.get(x), ('type', 'n')))
-        resource_platforms = assets.order_by('id').values_list('platform_id', flat=True)
-        node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=get_root)
-        pattern = re.compile(r'\(0\)?')
-        nodes = []
-        for node in node_all:
-            meta = node.get('meta', {})
-            if pattern.search(node['name']) or meta.get('type') == 'platform':
-                continue
-            _type = meta.get('_type')
-            if _type:
-                node['type'] = _type
-            meta.setdefault('data', {})
-            node['meta'] = meta
-            nodes.append(node)
-        if not self.is_sync:
-            return nodes
-
-        asset_nodes = []
-        for node in nodes:
-            node['open'] = True
-            tp = node.get('meta', {}).get('_type')
-            if not tp:
-                continue
-            assets = asset_type_map.get(tp, [])
-            asset_nodes += self.serialize_assets(assets, pid=node['id'])
-        return nodes + asset_nodes
-
-    def list(self, request, *args, **kwargs):
-        assets = self.get_assets()
-        nodes = self.to_tree_nodes(assets)
-        return Response(data=nodes)
+        return [params.get('category'), params.get('type')]
+
+    @lazyproperty
+    def query_asset_util(self):
+        return UserPermAssetUtil(self.user)
+
+    @timeit
+    def get_assets(self):
+        return self.query_asset_util.get_all_assets()
+
+    def _get_tree_nodes_async(self):
+        if not self.tp or not all(self.tp):
+            nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user)
+            return nodes, []
+        category, tp = self.tp
+        assets = self.get_assets().filter(platform__type=tp, platform__category=category)
+        return [], assets
+
+    def _get_tree_nodes_sync(self):
+        if self.request.query_params.get('lv'):
+            return []
+        nodes = self.query_asset_util.get_type_nodes_tree()
+        return nodes, []
+
+    @property
+    def serialize_asset_kwargs(self):
+        return {
+            'get_pid': lambda asset, platform: 'ROOT_{}_{}'.format(platform.category.upper(), platform.type),
+        }
+
+    def serialize_nodes(self, nodes, with_asset_amount=False):
+        return nodes
+
+    def get_nodes_assets(self):
+        if self.is_sync:
+            return self._get_tree_nodes_sync()
+        else:
+            return self._get_tree_nodes_async()
+
+    def get_page_assets(self):
+        if self.is_sync:
+            return self.get_assets()
+        else:
+            return []


 class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
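The tree endpoint is now paged: the first response carries the nodes plus the first page_limit assets and an X-JMS-TREE-OFFSET header, and follow-up requests pass that value back as ?offset= until no further assets are returned. A hypothetical client loop is shown below; the URL and session handling are placeholders, only the offset parameter and the header name come from the code above.

import requests

def fetch_whole_tree(session: requests.Session, url: str):
    nodes, offset = [], 0
    while True:
        params = {'offset': offset} if offset else {}
        resp = session.get(url, params=params)
        resp.raise_for_status()
        data = resp.json()
        nodes.extend(data)
        next_offset = resp.headers.get('X-JMS-TREE-OFFSET')
        # Stop when the server sends no data or the offset stops advancing.
        if not data or not next_offset or int(next_offset) <= offset:
            break
        offset = int(next_offset)
    return nodes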

View File

@@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _
 from accounts.const import AliasAccount
 from accounts.models import Account
 from assets.models import Asset
-from common.utils import date_expired_default
+from common.utils import date_expired_default, lazyproperty
 from common.utils.timezone import local_now
 from labels.mixins import LabeledMixin
 from orgs.mixins.models import JMSOrgBaseModel

@@ -105,6 +105,22 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
                 return True
         return False

+    @lazyproperty
+    def users_amount(self):
+        return self.users.count()
+
+    @lazyproperty
+    def user_groups_amount(self):
+        return self.user_groups.count()
+
+    @lazyproperty
+    def assets_amount(self):
+        return self.assets.count()
+
+    @lazyproperty
+    def nodes_amount(self):
+        return self.nodes.count()
+
     def get_all_users(self):
         from users.models import User
         user_ids = self.users.all().values_list('id', flat=True)

@@ -143,11 +159,14 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
     @classmethod
     def get_all_users_for_perms(cls, perm_ids, flat=False):
-        user_ids = cls.users.through.objects.filter(assetpermission_id__in=perm_ids) \
+        user_ids = cls.users.through.objects \
+            .filter(assetpermission_id__in=perm_ids) \
             .values_list('user_id', flat=True).distinct()
-        group_ids = cls.user_groups.through.objects.filter(assetpermission_id__in=perm_ids) \
+        group_ids = cls.user_groups.through.objects \
+            .filter(assetpermission_id__in=perm_ids) \
            .values_list('usergroup_id', flat=True).distinct()
-        group_user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \
+        group_user_ids = User.groups.through.objects \
+            .filter(usergroup_id__in=group_ids) \
            .values_list('user_id', flat=True).distinct()
         user_ids = set(user_ids) | set(group_user_ids)
         if flat:

View File

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 #
-from django.db.models import Q
+from django.db.models import Q, Count
 from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers

@@ -14,7 +14,7 @@ from orgs.mixins.serializers import BulkOrgResourceModelSerializer
 from perms.models import ActionChoices, AssetPermission
 from users.models import User, UserGroup

-__all__ = ["AssetPermissionSerializer", "ActionChoicesField"]
+__all__ = ["AssetPermissionSerializer", "ActionChoicesField", "AssetPermissionListSerializer"]


 class ActionChoicesField(BitChoicesField):

@@ -142,8 +142,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
     def perform_display_create(instance, **kwargs):
         # 用户
         users_to_set = User.objects.filter(
-            Q(name__in=kwargs.get("users_display"))
-            | Q(username__in=kwargs.get("users_display"))
+            Q(name__in=kwargs.get("users_display")) |
+            Q(username__in=kwargs.get("users_display"))
         ).distinct()
         instance.users.add(*users_to_set)
         # 用户组

@@ -153,8 +153,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
         instance.user_groups.add(*user_groups_to_set)
         # 资产
         assets_to_set = Asset.objects.filter(
-            Q(address__in=kwargs.get("assets_display"))
-            | Q(name__in=kwargs.get("assets_display"))
+            Q(address__in=kwargs.get("assets_display")) |
+            Q(name__in=kwargs.get("assets_display"))
         ).distinct()
         instance.assets.add(*assets_to_set)
         # 节点

@@ -180,3 +180,26 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
         instance = super().create(validated_data)
         self.perform_display_create(instance, **display)
         return instance
+
+
+class AssetPermissionListSerializer(AssetPermissionSerializer):
+    users_amount = serializers.IntegerField(read_only=True, label=_("Users amount"))
+    user_groups_amount = serializers.IntegerField(read_only=True, label=_("User groups amount"))
+    assets_amount = serializers.IntegerField(read_only=True, label=_("Assets amount"))
+    nodes_amount = serializers.IntegerField(read_only=True, label=_("Nodes amount"))
+
+    class Meta(AssetPermissionSerializer.Meta):
+        amount_fields = ["users_amount", "user_groups_amount", "assets_amount", "nodes_amount"]
+        remove_fields = {"users", "assets", "nodes", "user_groups"}
+        fields = list(set(AssetPermissionSerializer.Meta.fields + amount_fields) - remove_fields)
+
+    @classmethod
+    def setup_eager_loading(cls, queryset):
+        """Perform necessary eager loading of data."""
+        queryset = queryset.annotate(
+            users_amount=Count("users"),
+            user_groups_amount=Count("user_groups"),
+            assets_amount=Count("assets"),
+            nodes_amount=Count("nodes"),
+        )
+        return queryset

View File

@@ -3,15 +3,13 @@
 from django.db.models.signals import m2m_changed, pre_delete, pre_save, post_save
 from django.dispatch import receiver

-from users.models import User, UserGroup
 from assets.models import Asset
-from common.utils import get_logger, get_object_or_none
-from common.exceptions import M2MReverseNotAllowed
 from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR
+from common.exceptions import M2MReverseNotAllowed
+from common.utils import get_logger, get_object_or_none
 from perms.models import AssetPermission
 from perms.utils import UserPermTreeExpireUtil
+from users.models import User, UserGroup

 logger = get_logger(__file__)

@@ -38,7 +36,7 @@ def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs):
     group = UserGroup.objects.get(id=list(group_ids)[0])
     org_id = group.org_id
-    has_group_perm = AssetPermission.user_groups.through.objects\
+    has_group_perm = AssetPermission.user_groups.through.objects \
         .filter(usergroup_id__in=group_ids).exists()
     if not has_group_perm:
         return

@@ -115,6 +113,7 @@ def on_asset_permission_user_groups_changed(sender, instance, action, pk_set, re
 def on_node_asset_change(action, instance, reverse, pk_set, **kwargs):
     if not need_rebuild_mapping_node(action):
         return
+    print("Asset node changed: ", action)
     if reverse:
         asset_ids = pk_set
         node_ids = [instance.id]

View File

@@ -1,8 +1,7 @@
 from django.db.models import QuerySet

 from assets.models import Node, Asset
-from common.utils import get_logger
+from common.utils import get_logger, timeit
 from perms.models import AssetPermission

 logger = get_logger(__file__)

@@ -13,6 +12,7 @@ __all__ = ['AssetPermissionUtil']
 class AssetPermissionUtil(object):
     """ 资产授权相关的方法工具 """

+    @timeit
     def get_permissions_for_user(self, user, with_group=True, flat=False):
         """ 获取用户的授权规则 """
         perm_ids = set()

View File

@@ -1,13 +1,22 @@
-from django.conf import settings
-from django.db.models import Q
+import json
+import re
+
+from django.conf import settings
+from django.core.cache import cache
+from django.db.models import Q
+from rest_framework.utils.encoders import JSONEncoder

+from assets.const import AllTypes
 from assets.models import FavoriteAsset, Asset
-from common.utils.common import timeit
+from common.utils.common import timeit, get_logger
+from orgs.utils import current_org, tmp_to_root_org
 from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
 from .permission import AssetPermissionUtil

 __all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']

+logger = get_logger(__name__)


 class AssetPermissionPermAssetUtil:

@@ -16,29 +25,32 @@ class AssetPermissionPermAssetUtil:

     def get_all_assets(self):
         """ 获取所有授权的资产 """
-        node_asset_ids = self.get_perm_nodes_assets(flat=True)
-        direct_asset_ids = self.get_direct_assets(flat=True)
-        asset_ids = list(node_asset_ids) + list(direct_asset_ids)
-        assets = Asset.objects.filter(id__in=asset_ids)
-        return assets
+        node_assets = self.get_perm_nodes_assets()
+        direct_assets = self.get_direct_assets()
+        # 比原来的查到所有 asset id 再搜索块很多,因为当资产量大的时候,搜索会很慢
+        return (node_assets | direct_assets).distinct()

+    @timeit
     def get_perm_nodes_assets(self, flat=False):
         """ 获取所有授权节点下的资产 """
         from assets.models import Node
-        nodes = Node.objects.prefetch_related('granted_by_permissions').filter(
-            granted_by_permissions__in=self.perm_ids).only('id', 'key')
+        nodes = Node.objects \
+            .prefetch_related('granted_by_permissions') \
+            .filter(granted_by_permissions__in=self.perm_ids) \
+            .only('id', 'key')
         assets = PermNode.get_nodes_all_assets(*nodes)
         if flat:
-            return assets.values_list('id', flat=True)
+            return set(assets.values_list('id', flat=True))
         return assets

+    @timeit
     def get_direct_assets(self, flat=False):
         """ 获取直接授权的资产 """
         assets = Asset.objects.order_by() \
             .filter(granted_by_permissions__id__in=self.perm_ids) \
             .distinct()
         if flat:
-            return assets.values_list('id', flat=True)
+            return set(assets.values_list('id', flat=True))
         return assets

@@ -52,12 +64,62 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):

     def get_ungroup_assets(self):
         return self.get_direct_assets()

+    @timeit
     def get_favorite_assets(self):
-        assets = self.get_all_assets()
+        assets = Asset.objects.all().valid()
         asset_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True)
         assets = assets.filter(id__in=list(asset_ids))
         return assets

+    def get_type_nodes_tree(self):
+        assets = self.get_all_assets()
+        resource_platforms = assets.order_by('id').values_list('platform_id', flat=True)
+        node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=True)
+        pattern = re.compile(r'\(0\)?')
+        nodes = []
+        for node in node_all:
+            meta = node.get('meta', {})
+            if pattern.search(node['name']) or meta.get('type') == 'platform':
+                continue
+            _type = meta.get('_type')
+            if _type:
+                node['type'] = _type
+                node['category'] = meta.get('category')
+            meta.setdefault('data', {})
+            node['meta'] = meta
+            nodes.append(node)
+        return nodes
+
+    @classmethod
+    def get_type_nodes_tree_or_cached(cls, user):
+        key = f'perms:type-nodes-tree:{user.id}:{current_org.id}'
+        nodes = cache.get(key)
+        if nodes is None:
+            nodes = cls(user).get_type_nodes_tree()
+            nodes_json = json.dumps(nodes, cls=JSONEncoder)
+            cache.set(key, nodes_json, 60 * 60 * 24)
+        else:
+            nodes = json.loads(nodes)
+        return nodes
+
+    def refresh_type_nodes_tree_cache(self):
+        logger.debug("Refresh type nodes tree cache")
+        key = f'perms:type-nodes-tree:{self.user.id}:{current_org.id}'
+        cache.delete(key)
+
+    def refresh_favorite_assets(self):
+        favor_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True)
+        favor_ids = set(favor_ids)
+        with tmp_to_root_org():
+            valid_ids = self.get_all_assets() \
+                .filter(id__in=favor_ids) \
+                .values_list('id', flat=True)
+            valid_ids = set(valid_ids)
+            invalid_ids = favor_ids - valid_ids
+            FavoriteAsset.objects.filter(user=self.user, asset_id__in=invalid_ids).delete()
+
     def get_node_assets(self, key):
         node = PermNode.objects.get(key=key)
         node.compute_node_from_and_assets_amount(self.user)

@@ -134,7 +196,11 @@
         self.perm_ids = AssetPermissionUtil().get_permissions_for_user(self.user, flat=True)

     def get_favorite_node(self):
-        assets_amount = UserPermAssetUtil(self.user).get_favorite_assets().count()
+        favor_ids = FavoriteAsset.objects \
+            .filter(user=self.user) \
+            .values_list('asset_id') \
+            .distinct()
+        assets_amount = Asset.objects.all().valid().filter(id__in=favor_ids).count()
         return PermNode.get_favorite_node(assets_amount)

     def get_ungrouped_node(self):
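get_type_nodes_tree_or_cached above is a cache-aside wrapper: the per-user, per-org tree is stored as a JSON string for 24 hours, and refresh_type_nodes_tree_cache simply deletes the key. The generic shape of that pattern, reduced to a few lines with a placeholder builder function:

import json
from django.core.cache import cache

def get_or_build(key, build_fn, ttl=60 * 60 * 24):
    cached = cache.get(key)
    if cached is not None:
        return json.loads(cached)      # the cache holds a JSON string
    value = build_fn()
    cache.set(key, json.dumps(value), ttl)
    return value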

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from django.conf import settings from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from django.db import transaction
from assets.models import Asset from assets.models import Asset
from assets.utils import NodeAssetsUtil from assets.utils import NodeAssetsUtil
from common.db.models import output_as_string from common.db.models import output_as_string
from common.decorators import on_transaction_commit from common.decorators import on_transaction_commit, merge_delay_run
from common.utils import get_logger from common.utils import get_logger
from common.utils.common import lazyproperty, timeit from common.utils.common import lazyproperty, timeit
from orgs.models import Organization from orgs.models import Organization
@ -23,6 +24,7 @@ from perms.models import (
PermNode PermNode
) )
from users.models import User from users.models import User
from . import UserPermAssetUtil
from .permission import AssetPermissionUtil from .permission import AssetPermissionUtil
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -50,24 +52,74 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
     def __init__(self, user):
         self.user = user
-        self.orgs = self.user.orgs.distinct()
-        self.org_ids = [str(o.id) for o in self.orgs]
+
+    @lazyproperty
+    def orgs(self):
+        return self.user.orgs.distinct()
+
+    @lazyproperty
+    def org_ids(self):
+        return [str(o.id) for o in self.orgs]
 
     @lazyproperty
     def cache_key_user(self):
         return self.get_cache_key(self.user.id)
 
+    @lazyproperty
+    def cache_key_time(self):
+        key = 'perms.user.node_tree.built_time.{}'.format(self.user.id)
+        return key
+
     @timeit
     def refresh_if_need(self, force=False):
-        self._clean_user_perm_tree_for_legacy_org()
+        built_just_now = cache.get(self.cache_key_time)
+        if built_just_now:
+            logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
+            return
         to_refresh_orgs = self.orgs if force else self._get_user_need_refresh_orgs()
         if not to_refresh_orgs:
            logger.info('Not have to refresh orgs')
            return
-        with UserGrantedTreeRebuildLock(self.user.id):
+        logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
+        refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
+        refresh_user_favorite_assets(users=(self.user,))
+
+    @timeit
+    def refresh_tree_manual(self):
+        built_just_now = cache.get(self.cache_key_time)
+        if built_just_now:
+            logger.info('Refresh just now, pass: {}'.format(built_just_now))
+            return
+        to_refresh_orgs = self._get_user_need_refresh_orgs()
+        if not to_refresh_orgs:
+            logger.info('Not have to refresh orgs for user: {}'.format(self.user))
+            return
+        self.perform_refresh_user_tree(to_refresh_orgs)
+
+    @timeit
+    def perform_refresh_user_tree(self, to_refresh_orgs):
+        # Check again before rebuilding; constructing the tree is slow
+        built_just_now = cache.get(self.cache_key_time)
+        if built_just_now:
+            logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
+            return
+        self._clean_user_perm_tree_for_legacy_org()
+        ttl = settings.PERM_TREE_REGEN_INTERVAL
+        cache.set(self.cache_key_time, int(time.time()), ttl)
+        lock = UserGrantedTreeRebuildLock(self.user.id)
+        got = lock.acquire(blocking=False)
+        if not got:
+            logger.info('User perm tree rebuild lock not acquired, pass')
+            return
+        try:
             for org in to_refresh_orgs:
                 self._rebuild_user_perm_tree_for_org(org)
             self._mark_user_orgs_refresh_finished(to_refresh_orgs)
+        finally:
+            lock.release()
 
     def _rebuild_user_perm_tree_for_org(self, org):
         with tmp_to_org(org):
@@ -75,7 +127,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
             UserPermTreeBuildUtil(self.user).rebuild_user_perm_tree()
             end = time.time()
             logger.info(
-                'Refresh user [{user}] org [{org}] perm tree, user {use_time:.2f}s'
+                'Refresh user perm tree: [{user}] org [{org}] {use_time:.2f}s'
                 ''.format(user=self.user, org=org, use_time=end - start)
             )
@@ -90,7 +142,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
         cached_org_ids = self.client.smembers(self.cache_key_user)
         cached_org_ids = {oid.decode() for oid in cached_org_ids}
         to_refresh_org_ids = set(self.org_ids) - cached_org_ids
-        to_refresh_orgs = Organization.objects.filter(id__in=to_refresh_org_ids)
+        to_refresh_orgs = list(Organization.objects.filter(id__in=to_refresh_org_ids))
         logger.info(f'Need to refresh orgs: {to_refresh_orgs}')
         return to_refresh_orgs
@@ -128,7 +180,8 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
         self.expire_perm_tree_for_user_groups_orgs(group_ids, org_ids)
 
     def expire_perm_tree_for_user_groups_orgs(self, group_ids, org_ids):
-        user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \
+        user_ids = User.groups.through.objects \
+            .filter(usergroup_id__in=group_ids) \
             .values_list('user_id', flat=True).distinct()
         self.expire_perm_tree_for_users_orgs(user_ids, org_ids)
@@ -151,6 +204,21 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
         logger.info('Expire all user perm tree')
 
 
+@merge_delay_run(ttl=20)
+def refresh_user_orgs_perm_tree(user_orgs=()):
+    for user, orgs in user_orgs:
+        util = UserPermTreeRefreshUtil(user)
+        util.perform_refresh_user_tree(orgs)
+
+
+@merge_delay_run(ttl=20)
+def refresh_user_favorite_assets(users=()):
+    for user in users:
+        util = UserPermAssetUtil(user)
+        util.refresh_favorite_assets()
+        util.refresh_type_nodes_tree_cache()
+
+
 class UserPermTreeBuildUtil(object):
     node_only_fields = ('id', 'key', 'parent_key', 'org_id')
 
@@ -161,13 +229,14 @@ class UserPermTreeBuildUtil(object):
         self._perm_nodes_key_node_mapper = {}
 
     def rebuild_user_perm_tree(self):
-        self.clean_user_perm_tree()
-        if not self.user_perm_ids:
-            logger.info('User({}) not have permissions'.format(self.user))
-            return
-        self.compute_perm_nodes()
-        self.compute_perm_nodes_asset_amount()
-        self.create_mapping_nodes()
+        with transaction.atomic():
+            self.clean_user_perm_tree()
+            if not self.user_perm_ids:
+                logger.info('User({}) not have permissions'.format(self.user))
+                return
+            self.compute_perm_nodes()
+            self.compute_perm_nodes_asset_amount()
+            self.create_mapping_nodes()
 
     def clean_user_perm_tree(self):
         UserAssetGrantedTreeNodeRelation.objects.filter(user=self.user).delete()
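The rebuild path above puts three guards in front of the expensive per-org tree build: a short-lived cache key (cache_key_time) recording when a build last finished, a non-blocking UserGrantedTreeRebuildLock so concurrent callers return instead of queueing, and merge_delay_run so bursts of refresh requests are coalesced into one delayed call. Below is a minimal, self-contained sketch of the first two guards using in-process stand-ins; REGEN_INTERVAL, refresh_if_needed and rebuild_tree are illustrative names, not code from this repository.

import threading
import time

REGEN_INTERVAL = 10  # seconds; stands in for settings.PERM_TREE_REGEN_INTERVAL

_last_built_at = 0.0              # stands in for the cache_key_time cache entry
_rebuild_lock = threading.Lock()  # stands in for UserGrantedTreeRebuildLock


def rebuild_tree(org_ids):
    # Placeholder for the expensive per-org tree rebuild.
    for _ in org_ids:
        time.sleep(0.1)


def refresh_if_needed(org_ids):
    global _last_built_at
    # Guard 1: skip entirely if a rebuild finished inside the TTL window.
    if time.time() - _last_built_at < REGEN_INTERVAL:
        return False
    # Guard 2: only one caller rebuilds at a time; the rest bail out immediately.
    if not _rebuild_lock.acquire(blocking=False):
        return False
    try:
        _last_built_at = time.time()
        rebuild_tree(org_ids)
        return True
    finally:
        _rebuild_lock.release()


if __name__ == '__main__':
    print(refresh_if_needed(['org-1', 'org-2']))  # True: performed the rebuild
    print(refresh_if_needed(['org-1', 'org-2']))  # False: still inside the TTL window

In the actual change the timestamp lives in the shared cache and the lock is distributed, so the same short-circuit applies across processes and workers.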
View File
@@ -139,7 +139,7 @@ class RBACPermission(permissions.DjangoModelPermissions):
         if isinstance(perms, str):
             perms = [perms]
         has = request.user.has_perms(perms)
-        logger.debug('View require perms: {}, result: {}'.format(perms, has))
+        logger.debug('Api require perms: {}, result: {}'.format(perms, has))
         return has
 
     def has_object_permission(self, request, view, obj):
View File
@@ -6,7 +6,7 @@ from rest_framework.response import Response
 
 from orgs.mixins.api import OrgBulkModelViewSet
 from ..models import UserGroup, User
-from ..serializers import UserGroupSerializer
+from ..serializers import UserGroupSerializer, UserGroupListSerializer
 
 __all__ = ['UserGroupViewSet']
 
@@ -15,7 +15,10 @@ class UserGroupViewSet(OrgBulkModelViewSet):
     model = UserGroup
     filterset_fields = ("name",)
     search_fields = filterset_fields
-    serializer_class = UserGroupSerializer
+    serializer_classes = {
+        'default': UserGroupSerializer,
+        'list': UserGroupListSerializer,
+    }
     ordering = ('name',)
     rbac_perms = (
         ("add_all_users", "users.add_usergroup"),
View File
@@ -2,6 +2,7 @@
 #
 from django.db.models import Count
 from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
 
 from common.serializers.fields import ObjectRelatedField
 from common.serializers.mixin import ResourceLabelsMixin
@@ -10,7 +11,7 @@ from .. import utils
 from ..models import User, UserGroup
 
 __all__ = [
-    'UserGroupSerializer',
+    'UserGroupSerializer', 'UserGroupListSerializer',
 ]
 
@@ -29,7 +30,6 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
         fields = fields_mini + fields_small + ['users', 'labels']
         extra_kwargs = {
             'created_by': {'label': _('Created by'), 'read_only': True},
-            'users_amount': {'label': _('Users amount')},
             'id': {'label': _('ID')},
         }
 
@@ -45,6 +45,17 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
     @classmethod
     def setup_eager_loading(cls, queryset):
         """ Perform necessary eager loading of data. """
-        queryset = queryset.prefetch_related('users', 'labels', 'labels__label') \
+        queryset = queryset.prefetch_related('labels', 'labels__label') \
             .annotate(users_amount=Count('users'))
         return queryset
+
+
+class UserGroupListSerializer(UserGroupSerializer):
+    users_amount = serializers.IntegerField(label=_('Users amount'), read_only=True)
+
+    class Meta(UserGroupSerializer.Meta):
+        fields = list(set(UserGroupSerializer.Meta.fields + ['users_amount']) - {'users'})
+        extra_kwargs = {
+            **UserGroupSerializer.Meta.extra_kwargs,
+            'users_amount': {'label': _('Users amount')},
+        }
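The new UserGroupListSerializer is the list/detail split pattern: the list endpoint drops the heavy users relation and exposes only the annotated users_amount, so listing groups no longer prefetches every member. A generic DRF sketch of the same pattern under assumed names (Team, members, members_amount and myapp are invented for illustration; this is not JumpServer code):

from django.db.models import Count
from rest_framework import serializers, viewsets

from myapp.models import Team  # hypothetical model with a `members` many-to-many


class TeamSerializer(serializers.ModelSerializer):
    class Meta:
        model = Team
        fields = ['id', 'name', 'members']  # detail view returns the full member list


class TeamListSerializer(TeamSerializer):
    members_amount = serializers.IntegerField(read_only=True)

    class Meta(TeamSerializer.Meta):
        fields = ['id', 'name', 'members_amount']  # list view returns a cheap count


class TeamViewSet(viewsets.ModelViewSet):
    queryset = Team.objects.all()

    def get_serializer_class(self):
        # Same idea as the serializer_classes mapping: a lighter serializer for `list`.
        if self.action == 'list':
            return TeamListSerializer
        return TeamSerializer

    def get_queryset(self):
        qs = super().get_queryset()
        if self.action == 'list':
            qs = qs.annotate(members_amount=Count('members'))
        return qs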
View File
@@ -17,6 +17,7 @@ from resources.assets import AssetsGenerator, NodesGenerator, PlatformGenerator
 from resources.users import UserGroupGenerator, UserGenerator
 from resources.perms import AssetPermissionGenerator
 from resources.terminal import CommandGenerator, SessionGenerator
+from resources.accounts import AccountGenerator
 
 resource_generator_mapper = {
     'asset': AssetsGenerator,
@@ -27,6 +28,7 @@ resource_generator_mapper = {
     'asset_permission': AssetPermissionGenerator,
     'command': CommandGenerator,
     'session': SessionGenerator,
+    'account': AccountGenerator,
     'all': None
     # 'stat': StatGenerator
 }
@@ -45,6 +47,7 @@ def main():
     parser.add_argument('-o', '--org', type=str, default='')
     args = parser.parse_args()
     resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org
+    resource = resource.lower().rstrip('s')
     generator_cls = []
     if resource == 'all':
View File
@@ -0,0 +1,32 @@
+import random
+
+import forgery_py
+from accounts.models import Account
+from assets.models import Asset
+
+from .base import FakeDataGenerator
+
+
+class AccountGenerator(FakeDataGenerator):
+    resource = 'account'
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.assets = list(list(Asset.objects.all()[:5000]))
+
+    def do_generate(self, batch, batch_size):
+        accounts = []
+        for i in batch:
+            asset = random.choice(self.assets)
+            name = forgery_py.internet.user_name(True) + '-' + str(i)
+            d = {
+                'username': name,
+                'name': name,
+                'asset': asset,
+                'secret': name,
+                'secret_type': 'password',
+                'is_active': True,
+                'privileged': False,
+            }
+            accounts.append(Account(**d))
+        Account.objects.bulk_create(accounts, ignore_conflicts=True)
View File
@@ -48,7 +48,7 @@ class AssetsGenerator(FakeDataGenerator):
     def pre_generate(self):
         self.node_ids = list(Node.objects.all().values_list('id', flat=True))
-        self.platform_ids = list(Platform.objects.all().values_list('id', flat=True))
+        self.platform_ids = list(Platform.objects.filter(category='host').values_list('id', flat=True))
 
     def set_assets_nodes(self, assets):
         for asset in assets:
@@ -72,6 +72,17 @@ class AssetsGenerator(FakeDataGenerator):
             assets.append(Asset(**data))
         creates = Asset.objects.bulk_create(assets, ignore_conflicts=True)
         self.set_assets_nodes(creates)
+        self.set_asset_platform(creates)
+
+    @staticmethod
+    def set_asset_platform(assets):
+        protocol = random.choice(['ssh', 'rdp', 'telnet', 'vnc'])
+        protocols = []
+        for asset in assets:
+            port = 22 if protocol == 'ssh' else 3389
+            protocols.append(Protocol(asset=asset, name=protocol, port=port))
+        Protocol.objects.bulk_create(protocols, ignore_conflicts=True)
 
     def after_generate(self):
         pass
View File
@@ -41,7 +41,7 @@ class FakeDataGenerator:
             start = time.time()
             self.do_generate(batch, self.batch_size)
             end = time.time()
-            using = end - start
+            using = round(end - start, 3)
             from_size = created
             created += len(batch)
             print('Generate %s: %s-%s [%s]' % (self.resource, from_size, created, using))
View File
@@ -1,9 +1,11 @@
-from random import choice, sample
+from random import sample
 
 import forgery_py
 
-from .base import FakeDataGenerator
+from orgs.utils import current_org
+from rbac.models import RoleBinding, Role
 from users.models import *
+from .base import FakeDataGenerator
 
 
 class UserGroupGenerator(FakeDataGenerator):
@@ -47,3 +49,12 @@ class UserGenerator(FakeDataGenerator):
             users.append(u)
         users = User.objects.bulk_create(users, ignore_conflicts=True)
         self.set_groups(users)
+        self.set_to_org(users)
+
+    def set_to_org(self, users):
+        bindings = []
+        role = Role.objects.get(name='OrgUser')
+        for u in users:
+            b = RoleBinding(user=u, role=role, org_id=current_org.id, scope='org')
+            bindings.append(b)
+        RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True)