From 9be3cbb936e530ccb5220d5edd3e6d61ca3959d9 Mon Sep 17 00:00:00 2001 From: xinwen Date: Mon, 8 Feb 2021 14:59:20 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E8=AF=A6=E6=83=85=E9=A1=B5=E6=8E=88=E6=9D=83=E5=88=97=E8=A1=A8?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E9=80=9F=E5=BA=A6&=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=8F=AF=E9=87=8D=E5=85=A5=E9=94=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/assets/api/asset.py | 2 - apps/assets/api/mixin.py | 4 +- apps/assets/api/node.py | 8 +- apps/assets/locks.py | 5 +- ...s_amount.py => 0066_auto_20210208_1802.py} | 8 +- apps/assets/models/asset.py | 2 +- apps/assets/models/node.py | 6 +- apps/assets/pagination.py | 48 ++++-- apps/assets/signals_handler/__init__.py | 3 +- .../signals_handler/node_assets_amount.py | 159 ++++++++++++++++++ ...n_nodes_tree.py => node_assets_mapping.py} | 0 apps/assets/tasks/nodes_amount.py | 33 ++++ apps/assets/utils.py | 37 +++- apps/common/utils/lock.py | 141 +++++++++++++--- apps/orgs/utils.py | 10 +- apps/perms/api/asset/asset_permission.py | 6 - .../user_permission_nodes_with_assets.py | 15 +- apps/perms/api/base.py | 2 +- ...204_1749.py => 0018_auto_20210208_1515.py} | 4 +- apps/perms/pagination.py | 36 +--- apps/perms/serializers/asset/permission.py | 11 +- apps/perms/utils/asset/user_permission.py | 18 +- 22 files changed, 434 insertions(+), 124 deletions(-) rename apps/assets/migrations/{0066_remove_node_assets_amount.py => 0066_auto_20210208_1802.py} (50%) create mode 100644 apps/assets/signals_handler/node_assets_amount.py rename apps/assets/signals_handler/{maintain_nodes_tree.py => node_assets_mapping.py} (100%) rename apps/perms/migrations/{0018_auto_20210204_1749.py => 0018_auto_20210208_1515.py} (96%) diff --git a/apps/assets/api/asset.py b/apps/assets/api/asset.py index 9d8a6bf89..2176f97aa 100644 --- a/apps/assets/api/asset.py +++ b/apps/assets/api/asset.py @@ -4,9 +4,7 @@ from assets.api import FilterAssetByNodeMixin from rest_framework.viewsets import ModelViewSet from rest_framework.generics import RetrieveAPIView from django.shortcuts import get_object_or_404 -from django.utils.decorators import method_decorator -from assets.locks import NodeTreeUpdateLock from common.utils import get_logger, get_object_or_none from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser from orgs.mixins.api import OrgBulkModelViewSet diff --git a/apps/assets/api/mixin.py b/apps/assets/api/mixin.py index 763f025f4..f7738f3f1 100644 --- a/apps/assets/api/mixin.py +++ b/apps/assets/api/mixin.py @@ -2,7 +2,7 @@ from typing import List from common.utils.common import timeit from assets.models import Node, Asset -from assets.pagination import AssetLimitOffsetPagination +from assets.pagination import NodeAssetTreePagination from common.utils import lazyproperty from assets.utils import get_node, is_query_node_all_assets @@ -81,7 +81,7 @@ class SerializeToTreeNodeMixin: class FilterAssetByNodeMixin: - pagination_class = AssetLimitOffsetPagination + pagination_class = NodeAssetTreePagination @lazyproperty def is_query_node_all_assets(self): diff --git a/apps/assets/api/node.py b/apps/assets/api/node.py index c793ad384..4ffdbfe1a 100644 --- a/apps/assets/api/node.py +++ b/apps/assets/api/node.py @@ -8,7 +8,6 @@ from rest_framework.response import Response from rest_framework.decorators import action from django.utils.translation import ugettext_lazy as _ from django.shortcuts import get_object_or_404, Http404 -from django.utils.decorators import method_decorator from django.db.models.signals import m2m_changed from common.const.http import POST @@ -25,10 +24,10 @@ from ..models import Node from ..tasks import ( update_node_assets_hardware_info_manual, test_node_assets_connectivity_manual, + check_node_assets_amount_task ) from .. import serializers from .mixin import SerializeToTreeNodeMixin -from assets.locks import NodeTreeUpdateLock logger = get_logger(__file__) @@ -54,6 +53,11 @@ class NodeViewSet(OrgModelViewSet): serializer.validated_data["key"] = child_key serializer.save() + @action(methods=[POST], detail=False, url_path='check_assets_amount_task') + def check_assets_amount_task(self, request): + task = check_node_assets_amount_task.delay(current_org.id) + return Response(data={'task': task.id}) + def perform_update(self, serializer): node = self.get_object() if node.is_org_root() and node.value != serializer.validated_data['value']: diff --git a/apps/assets/locks.py b/apps/assets/locks.py index b80db3ff8..bdab57080 100644 --- a/apps/assets/locks.py +++ b/apps/assets/locks.py @@ -15,7 +15,6 @@ class NodeTreeUpdateLock(DistributedLock): ) return name - def __init__(self, blocking=True): + def __init__(self): name = self.get_name() - super().__init__(name=name, blocking=blocking, - release_lock_on_transaction_commit=True) + super().__init__(name=name, release_on_transaction_commit=True, reentrant=True) diff --git a/apps/assets/migrations/0066_remove_node_assets_amount.py b/apps/assets/migrations/0066_auto_20210208_1802.py similarity index 50% rename from apps/assets/migrations/0066_remove_node_assets_amount.py rename to apps/assets/migrations/0066_auto_20210208_1802.py index 5d7044179..ffe7d8fb5 100644 --- a/apps/assets/migrations/0066_remove_node_assets_amount.py +++ b/apps/assets/migrations/0066_auto_20210208_1802.py @@ -1,4 +1,4 @@ -# Generated by Django 3.1 on 2021-02-04 09:49 +# Generated by Django 3.1 on 2021-02-08 10:02 from django.db import migrations @@ -10,8 +10,8 @@ class Migration(migrations.Migration): ] operations = [ - migrations.RemoveField( - model_name='node', - name='assets_amount', + migrations.AlterModelOptions( + name='asset', + options={'ordering': ['hostname'], 'verbose_name': 'Asset'}, ), ] diff --git a/apps/assets/models/asset.py b/apps/assets/models/asset.py index 5d133f40d..a21778a42 100644 --- a/apps/assets/models/asset.py +++ b/apps/assets/models/asset.py @@ -353,4 +353,4 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin): class Meta: unique_together = [('org_id', 'hostname')] verbose_name = _("Asset") - ordering = ["hostname", "ip"] + ordering = ["hostname", ] diff --git a/apps/assets/models/node.py b/apps/assets/models/node.py index e5b53eb45..85d857a7c 100644 --- a/apps/assets/models/node.py +++ b/apps/assets/models/node.py @@ -425,11 +425,6 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin): node_ids.update(_ids) return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct() - @property - def assets_amount(self): - assets_id = self.get_all_assets_id() - return len(assets_id) - def get_all_assets_id(self): assets_id = self.get_all_assets_id_by_node_key(org_id=self.org_id, node_key=self.key) return set(assets_id) @@ -550,6 +545,7 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin): date_create = models.DateTimeField(auto_now_add=True) parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"), db_index=True, default='') + assets_amount = models.IntegerField(default=0) objects = OrgManager.from_queryset(NodeQuerySet)() is_node = True diff --git a/apps/assets/pagination.py b/apps/assets/pagination.py index 4fd866e3d..7a55c1306 100644 --- a/apps/assets/pagination.py +++ b/apps/assets/pagination.py @@ -1,39 +1,51 @@ from rest_framework.pagination import LimitOffsetPagination from rest_framework.request import Request +from common.utils import get_logger from assets.models import Node +logger = get_logger(__name__) + + +class AssetPaginationBase(LimitOffsetPagination): + + def init_attrs(self, queryset, request: Request, view=None): + self._request = request + self._view = view + self._user = request.user + + def paginate_queryset(self, queryset, request: Request, view=None): + self.init_attrs(queryset, request, view) + return super().paginate_queryset(queryset, request, view=None) -class AssetLimitOffsetPagination(LimitOffsetPagination): - """ - 需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用 - """ def get_count(self, queryset): - """ - 1. 如果查询节点下的所有资产,那 count 使用 Node.assets_amount - 2. 如果有其他过滤条件使用 super - 3. 如果只查询该节点下的资产使用 super - """ exclude_query_params = { self.limit_query_param, self.offset_query_param, - 'node', 'all', 'show_current_asset', - 'node_id', 'display', 'draw', 'fields_size', + 'key', 'all', 'show_current_asset', + 'cache_policy', 'display', 'draw', + 'order', 'node', 'node_id', 'fields_size', } - for k, v in self._request.query_params.items(): if k not in exclude_query_params and v is not None: + logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}') return super().get_count(queryset) + node_assets_count = self.get_count_from_nodes(queryset) + if node_assets_count is None: + return super().get_count(queryset) + return node_assets_count + def get_count_from_nodes(self, queryset): + raise NotImplementedError + + +class NodeAssetTreePagination(AssetPaginationBase): + def get_count_from_nodes(self, queryset): is_query_all = self._view.is_query_node_all_assets if is_query_all: node = self._view.node if not node: node = Node.org_root() + logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}') return node.assets_amount - return super().get_count(queryset) - - def paginate_queryset(self, queryset, request: Request, view=None): - self._request = request - self._view = view - return super().paginate_queryset(queryset, request, view=None) + return None diff --git a/apps/assets/signals_handler/__init__.py b/apps/assets/signals_handler/__init__.py index 0c3980565..c8f332f26 100644 --- a/apps/assets/signals_handler/__init__.py +++ b/apps/assets/signals_handler/__init__.py @@ -1,2 +1,3 @@ from .common import * -from .maintain_nodes_tree import * +from .node_assets_amount import * +from .node_assets_mapping import * diff --git a/apps/assets/signals_handler/node_assets_amount.py b/apps/assets/signals_handler/node_assets_amount.py new file mode 100644 index 000000000..4501d0226 --- /dev/null +++ b/apps/assets/signals_handler/node_assets_amount.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +# +from operator import add, sub +from django.db.models import Q, F +from django.dispatch import receiver +from django.db.models.signals import ( + m2m_changed +) + +from orgs.utils import ensure_in_real_or_default_org +from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR +from common.utils import get_logger +from assets.models import Asset, Node, compute_parent_key +from assets.locks import NodeTreeUpdateLock + + +logger = get_logger(__file__) + + +@receiver(m2m_changed, sender=Asset.nodes.through) +def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs): + # 不允许 `pre_clear` ,因为该信号没有 `pk_set` + # [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed) + refused = (PRE_CLEAR,) + if action in refused: + raise ValueError + + mapper = { + PRE_ADD: add, + POST_REMOVE: sub + } + if action not in mapper: + return + + operator = mapper[action] + + if reverse: + node: Node = instance + asset_pk_set = set(pk_set) + NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator) + else: + asset_pk = instance.id + # 与资产直接关联的节点 + node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True)) + NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator) + + +class NodeAssetsAmountUtils: + + @classmethod + def _remove_ancestor_keys(cls, ancestor_key, tree_set): + # 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环 + # 判断是否在集合里,来区分是否已被处理过 + while ancestor_key and ancestor_key in tree_set: + tree_set.remove(ancestor_key) + ancestor_key = compute_parent_key(ancestor_key) + + @classmethod + def _is_asset_exists_in_node(cls, asset_pk, node_key): + exists = Asset.objects.filter( + Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key) + ).filter(id=asset_pk).exists() + return exists + + @classmethod + @ensure_in_real_or_default_org + @NodeTreeUpdateLock() + def update_nodes_asset_amount(cls, node_keys, asset_pk, operator): + """ + 一个资产与多个节点关系变化时,更新计数 + + :param node_keys: 节点 id 的集合 + :param asset_pk: 资产 id + :param operator: 操作 + """ + + # 所有相关节点的祖先节点,组成一棵局部树 + ancestor_keys = set() + for key in node_keys: + ancestor_keys.update(Node.get_node_ancestor_keys(key)) + + # 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉 + node_keys -= ancestor_keys + + to_update_keys = [] + for key in node_keys: + # 遍历相关节点,处理它及其祖先节点 + # 查询该节点是否包含待处理资产 + exists = cls._is_asset_exists_in_node(asset_pk, key) + parent_key = compute_parent_key(key) + + if exists: + # 如果资产在该节点,那么他及其祖先节点都不用处理 + cls._remove_ancestor_keys(parent_key, ancestor_keys) + continue + else: + # 不存在,要更新本节点 + to_update_keys.append(key) + # 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环 + # 判断是否在集合里,来区分是否已被处理过 + while parent_key and parent_key in ancestor_keys: + exists = cls._is_asset_exists_in_node(asset_pk, parent_key) + if exists: + cls._remove_ancestor_keys(parent_key, ancestor_keys) + break + else: + to_update_keys.append(parent_key) + ancestor_keys.remove(parent_key) + parent_key = compute_parent_key(parent_key) + + Node.objects.filter(key__in=to_update_keys).update( + assets_amount=operator(F('assets_amount'), 1) + ) + + @classmethod + @ensure_in_real_or_default_org + @NodeTreeUpdateLock() + def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add): + """ + 一个节点与多个资产关系变化时,更新计数 + + :param node: 节点实例 + :param asset_pk_set: 资产的`id`集合, 内部不会修改该值 + :param operator: 操作 + * -> Node + # -> Asset + + * [3] + / \ + * * [2] + / \ + * * [1] + / / \ + * [a] # # [b] + + """ + # 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key` + ancestor_keys = node.get_ancestor_keys(with_self=True) + ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key') + to_update = [] + for ancestor in ancestors: + # 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3] + # 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的 + # 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量 + + asset_pk_set -= set(Asset.objects.filter( + id__in=asset_pk_set + ).filter( + Q(nodes__key__istartswith=f'{ancestor.key}:') | + Q(nodes__key=ancestor.key) + ).distinct().values_list('id', flat=True)) + if not asset_pk_set: + # 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量 + # 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用 + # 处理了 + break + ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set)) + to_update.append(ancestor) + Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key')) diff --git a/apps/assets/signals_handler/maintain_nodes_tree.py b/apps/assets/signals_handler/node_assets_mapping.py similarity index 100% rename from apps/assets/signals_handler/maintain_nodes_tree.py rename to apps/assets/signals_handler/node_assets_mapping.py diff --git a/apps/assets/tasks/nodes_amount.py b/apps/assets/tasks/nodes_amount.py index e69de29bb..a7cb46a45 100644 --- a/apps/assets/tasks/nodes_amount.py +++ b/apps/assets/tasks/nodes_amount.py @@ -0,0 +1,33 @@ +from celery import shared_task +from django.utils.translation import gettext_lazy as _ + +from orgs.models import Organization +from orgs.utils import tmp_to_org +from ops.celery.decorator import register_as_period_task +from assets.utils import check_node_assets_amount + +from common.utils.lock import AcquireFailed +from common.utils import get_logger + +logger = get_logger(__file__) + + +@shared_task +def check_node_assets_amount_task(orgid=None): + if orgid is None: + orgs = [*Organization.objects.all(), Organization.default()] + else: + orgs = [Organization.get_instance(orgid)] + + for org in orgs: + try: + with tmp_to_org(org): + check_node_assets_amount() + except AcquireFailed: + logger.error(_('The task of self-checking is already running and cannot be started repeatedly')) + + +@register_as_period_task(crontab='0 2 * * *') +@shared_task +def check_node_assets_amount_period_task(): + check_node_assets_amount_task() diff --git a/apps/assets/utils.py b/apps/assets/utils.py index 343fa704b..c9857f802 100644 --- a/apps/assets/utils.py +++ b/apps/assets/utils.py @@ -5,12 +5,45 @@ from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, from common.http import is_true from common.struct import Stack from common.db.models import output_as_string +from orgs.utils import ensure_in_real_or_default_org, current_org -from .models import Node +from .locks import NodeTreeUpdateLock +from .models import Node, Asset logger = get_logger(__file__) +@NodeTreeUpdateLock() +@ensure_in_real_or_default_org +def check_node_assets_amount(): + logger.info(f'Check node assets amount {current_org}') + nodes = list(Node.objects.all().only('id', 'key', 'assets_amount')) + nodeid_assetid_pairs = list(Asset.nodes.through.objects.all().values_list('node_id', 'asset_id')) + + nodekey_assetids_mapper = defaultdict(set) + nodeid_nodekey_mapper = {} + for node in nodes: + nodeid_nodekey_mapper[node.id] = node.key + + for nodeid, assetid in nodeid_assetid_pairs: + if nodeid not in nodeid_nodekey_mapper: + continue + nodekey = nodeid_nodekey_mapper[nodeid] + nodekey_assetids_mapper[nodekey].add(assetid) + + util = NodeAssetsUtil(nodes, nodekey_assetids_mapper) + util.generate() + + to_updates = [] + for node in nodes: + assets_amount = util.get_assets_amount(node.key) + if node.assets_amount != assets_amount: + logger.error(f'Node[{node.key}] assets amount error {node.assets_amount} != {assets_amount}') + node.assets_amount = assets_amount + to_updates.append(node) + Node.objects.bulk_update(to_updates, fields=('assets_amount',)) + + def is_query_node_all_assets(request): request = request query_all_arg = request.query_params.get('all', 'true') @@ -104,5 +137,3 @@ class NodeAssetsUtil: util = cls(nodes, mapping) util.generate() return util - - diff --git a/apps/common/utils/lock.py b/apps/common/utils/lock.py index d7d7acbed..1a016d3d4 100644 --- a/apps/common/utils/lock.py +++ b/apps/common/utils/lock.py @@ -8,6 +8,7 @@ from django.db import transaction from common.utils import get_logger from common.utils.inspect import copy_function_args from apps.jumpserver.const import CONFIG +from common.local import thread_local logger = get_logger(__file__) @@ -16,24 +17,28 @@ class AcquireFailed(RuntimeError): pass +class LockHasTimeOut(RuntimeError): + pass + + class DistributedLock(RedisLock): - def __init__(self, name, blocking=True, expire=None, release_lock_on_transaction_commit=False, - release_raise_exc=False, auto_renewal_seconds=60*2): + def __init__(self, name, *, expire=None, release_on_transaction_commit=False, + reentrant=False, release_raise_exc=False, auto_renewal_seconds=60): """ 使用 redis 构造的分布式锁 :param name: 锁的名字,要全局唯一 - :param blocking: - 该参数只在锁作为装饰器或者 `with` 时有效。 :param expire: 锁的过期时间 - :param release_lock_on_transaction_commit: + :param release_on_transaction_commit: 是否在当前事务结束后再释放锁 :param release_raise_exc: 释放锁时,如果没有持有锁是否抛异常或静默 :param auto_renewal_seconds: 当持有一个无限期锁的时候,刷新锁的时间,具体参考 `redis_lock.Lock#auto_renewal` + :param reentrant: + 是否可重入 """ self.kwargs_copy = copy_function_args(self.__init__, locals()) redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD) @@ -45,28 +50,20 @@ class DistributedLock(RedisLock): auto_renewal = False super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal) - self._blocking = blocking - self._release_lock_on_transaction_commit = release_lock_on_transaction_commit + self._release_on_transaction_commit = release_on_transaction_commit self._release_raise_exc = release_raise_exc + self._reentrant = reentrant + self._acquired_reentrant_lock = False + self._thread_id = threading.current_thread().ident def __enter__(self): - thread_id = threading.current_thread().ident - logger.debug(f'Attempt to acquire global lock: thread {thread_id} lock {self._name}') - acquired = self.acquire(blocking=self._blocking) - if self._blocking and not acquired: - logger.debug(f'Not acquired lock, but blocking=True, thread {thread_id} lock {self._name}') - raise EnvironmentError("Lock wasn't acquired, but blocking=True") + acquired = self.acquire(blocking=True) if not acquired: - logger.debug(f'Not acquired the lock, thread {thread_id} lock {self._name}') raise AcquireFailed - logger.debug(f'Acquire lock success, thread {thread_id} lock {self._name}') return self def __exit__(self, exc_type=None, exc_value=None, traceback=None): - if self._release_lock_on_transaction_commit: - transaction.on_commit(self.release) - else: - self.release() + self.release() def __call__(self, func): @wraps(func) @@ -82,9 +79,105 @@ class DistributedLock(RedisLock): return True return False - def release(self): + def locked_by_current_thread(self): + if self.locked(): + owner_id = self.get_owner_id() + local_owner_id = getattr(thread_local, self.name, None) + + if local_owner_id and owner_id == local_owner_id: + return True + return False + + def acquire(self, blocking=True, timeout=None): + if self._reentrant: + if self.locked_by_current_thread(): + self._acquired_reentrant_lock = True + logger.debug( + f'I[{self.id}] reentry lock[{self.name}] in thread[{self._thread_id}].') + return True + + logger.debug(f'I[{self.id}] attempt acquire reentrant-lock[{self.name}].') + acquired = super().acquire(blocking=blocking, timeout=timeout) + if acquired: + logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] now.') + setattr(thread_local, self.name, self.id) + else: + logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] failed.') + return acquired + else: + logger.debug(f'I[{self.id}] attempt acquire lock[{self.name}].') + acquired = super().acquire(blocking=blocking, timeout=timeout) + logger.debug(f'I[{self.id}] acquired lock[{self.name}] {acquired}.') + return acquired + + @property + def name(self): + return self._name + + def _raise_exc_with_log(self, msg, *, exc_cls=NotAcquired): + e = exc_cls(msg) + logger.error(msg) + self._raise_exc(e) + + def _raise_exc(self, e): + if self._release_raise_exc: + raise e + + def _release_on_reentrant_locked_by_brother(self): + if self._acquired_reentrant_lock: + self._acquired_reentrant_lock = False + logger.debug(f'I[{self.id}] released reentrant-lock[{self.name}] owner[{self.get_owner_id()}] in thread[{self._thread_id}]') + return + else: + self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired by me[{self.id}].') + + def _release_on_reentrant_locked_by_me(self): + logger.debug(f'I[{self.id}] release reentrant-lock[{self.name}] in thread[{self._thread_id}]') + + id = getattr(thread_local, self.name, None) + if id != self.id: + raise PermissionError(f'Reentrant-lock[{self.name}] is not locked by me[{self.id}], owner[{id}]') try: - super().release() - except AcquireFailed as e: - if self._release_raise_exc: - raise e + # 这里要保证先删除 thread_local 的标记, + delattr(thread_local, self.name) + except AttributeError: + pass + finally: + try: + # 这里处理的是边界情况, + # 判断锁是我的 -> 锁超时 -> 释放锁报错 + # 此时的报错应该被静默 + self._release_redis_lock() + except NotAcquired: + pass + + def _release_redis_lock(self): + # 最底层 api + super().release() + + def _release(self): + try: + self._release_redis_lock() + except NotAcquired as e: + logger.error(f'I[{self.id}] release lock[{self.name}] failed {e}') + self._raise_exc(e) + + def release(self): + _release = self._release + + # 处理可重入锁 + if self._reentrant: + if self.locked_by_current_thread(): + if self.locked_by_me(): + _release = self._release_on_reentrant_locked_by_me + else: + _release = self._release_on_reentrant_locked_by_brother + else: + self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired in current-thread[{self._thread_id}]') + + # 处理是否在事务提交时才释放锁 + if self._release_on_transaction_commit: + logger.debug(f'I[{self.id}] release lock[{self.name}] on transaction commit ...') + transaction.on_commit(_release) + else: + _release() diff --git a/apps/orgs/utils.py b/apps/orgs/utils.py index d01ae3f77..eef12b1c6 100644 --- a/apps/orgs/utils.py +++ b/apps/orgs/utils.py @@ -186,6 +186,10 @@ def org_aware_func(org_arg_name): current_org = LocalProxy(get_current_org) -def ensure_in_real_or_default_org(): - if not current_org or current_org.is_root(): - raise ValueError('You must in a real or default org!') +def ensure_in_real_or_default_org(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not current_org or current_org.is_root(): + raise ValueError('You must in a real or default org!') + return func(*args, **kwargs) + return wrapper diff --git a/apps/perms/api/asset/asset_permission.py b/apps/perms/api/asset/asset_permission.py index 5062f2099..e38a59ba6 100644 --- a/apps/perms/api/asset/asset_permission.py +++ b/apps/perms/api/asset/asset_permission.py @@ -26,12 +26,6 @@ class AssetPermissionViewSet(BasePermissionViewSet): 'node_id', 'node', 'asset_id', 'hostname', 'ip' ] - def get_queryset(self): - queryset = super().get_queryset().prefetch_related( - "nodes", "assets", "users", "user_groups", "system_users" - ) - return queryset - def filter_node(self, queryset): node_id = self.request.query_params.get('node_id') node_name = self.request.query_params.get('node') diff --git a/apps/perms/api/asset/user_permission/user_permission_nodes_with_assets.py b/apps/perms/api/asset/user_permission/user_permission_nodes_with_assets.py index 253a925ca..63619e9c1 100644 --- a/apps/perms/api/asset/user_permission/user_permission_nodes_with_assets.py +++ b/apps/perms/api/asset/user_permission/user_permission_nodes_with_assets.py @@ -14,7 +14,6 @@ from .mixin import RoleUserMixin, RoleAdminMixin from perms.utils.asset.user_permission import ( UserGrantedTreeBuildUtils, get_user_all_asset_perm_ids, UserGrantedNodesQueryUtils, UserGrantedAssetsQueryUtils, - QuerySetStage, ) from perms.models import AssetPermission, PermNode from assets.models import Asset @@ -44,10 +43,10 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView): def add_favorite_resource(self, data: list, nodes_query_utils, assets_query_utils): favorite_node = nodes_query_utils.get_favorite_node() - qs_state = QuerySetStage().annotate( + favorite_assets = assets_query_utils.get_favorite_assets() + favorite_assets = favorite_assets.annotate( parent_key=Value(favorite_node.key, output_field=CharField()) ).prefetch_related('platform') - favorite_assets = assets_query_utils.get_favorite_assets(qs_stage=qs_state, only=()) data.extend(self.serialize_nodes([favorite_node], with_asset_amount=True)) data.extend(self.serialize_assets(favorite_assets)) @@ -59,13 +58,11 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView): data.extend(self.serialize_nodes(nodes, with_asset_amount=True)) def add_assets(self, data: list, assets_query_utils: UserGrantedAssetsQueryUtils): - qs_stage = QuerySetStage().annotate(parent_key=F('nodes__key')).prefetch_related('platform') - if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: - all_assets = assets_query_utils.get_direct_granted_nodes_assets(qs_stage=qs_stage) + all_assets = assets_query_utils.get_direct_granted_nodes_assets() else: - all_assets = assets_query_utils.get_all_granted_assets(qs_stage=qs_stage) - + all_assets = assets_query_utils.get_all_granted_assets() + all_assets = all_assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform') data.extend(self.serialize_assets(all_assets)) @tmp_to_root_org() @@ -144,8 +141,6 @@ class GrantedNodeChildrenWithAssetsAsTreeApiMixin(SerializeToTreeNodeMixin, assets = assets_query_utils.get_node_assets(key) assets = assets.prefetch_related('platform') - user = self.user - tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True) tree_assets = self.serialize_assets(assets, key) return Response(data=[*tree_nodes, *tree_assets]) diff --git a/apps/perms/api/base.py b/apps/perms/api/base.py index 3d9c8672d..8c0028baf 100644 --- a/apps/perms/api/base.py +++ b/apps/perms/api/base.py @@ -45,7 +45,7 @@ class BasePermissionViewSet(OrgBulkModelViewSet): if not self.is_query_all(): queryset = queryset.filter(users=user) return queryset - groups = user.groups.all() + groups = list(user.groups.all().values_list('id', flat=True)) queryset = queryset.filter( Q(users=user) | Q(user_groups__in=groups) ).distinct() diff --git a/apps/perms/migrations/0018_auto_20210204_1749.py b/apps/perms/migrations/0018_auto_20210208_1515.py similarity index 96% rename from apps/perms/migrations/0018_auto_20210204_1749.py rename to apps/perms/migrations/0018_auto_20210208_1515.py index 00d567c2f..b5271f36f 100644 --- a/apps/perms/migrations/0018_auto_20210204_1749.py +++ b/apps/perms/migrations/0018_auto_20210208_1515.py @@ -1,4 +1,4 @@ -# Generated by Django 3.1 on 2021-02-04 09:49 +# Generated by Django 3.1 on 2021-02-08 07:15 import assets.models.node from django.conf import settings @@ -9,8 +9,8 @@ import django.db.models.deletion class Migration(migrations.Migration): dependencies = [ - ('assets', '0066_remove_node_assets_amount'), migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('assets', '0065_auto_20210121_1549'), ('perms', '0017_auto_20210104_0435'), ] diff --git a/apps/perms/pagination.py b/apps/perms/pagination.py index fc5e43de7..c740830c2 100644 --- a/apps/perms/pagination.py +++ b/apps/perms/pagination.py @@ -1,37 +1,17 @@ -from rest_framework.pagination import LimitOffsetPagination +from django.conf import settings from rest_framework.request import Request -from django.db.models import Sum +from assets.pagination import AssetPaginationBase from perms.models import UserAssetGrantedTreeNodeRelation from common.utils import get_logger logger = get_logger(__name__) -class GrantedAssetPaginationBase(LimitOffsetPagination): - - def paginate_queryset(self, queryset, request: Request, view=None): - self._request = request - self._view = view - self._user = request.user - return super().paginate_queryset(queryset, request, view=None) - - def get_count(self, queryset): - exclude_query_params = { - self.limit_query_param, - self.offset_query_param, - 'key', 'all', 'show_current_asset', - 'cache_policy', 'display', 'draw', - 'order', - } - for k, v in self._request.query_params.items(): - if k not in exclude_query_params and v is not None: - logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}') - return super().get_count(queryset) - return self.get_count_from_nodes(queryset) - - def get_count_from_nodes(self, queryset): - raise NotImplementedError +class GrantedAssetPaginationBase(AssetPaginationBase): + def init_attrs(self, queryset, request: Request, view=None): + super().init_attrs(queryset, request, view) + self._user = view.user class NodeGrantedAssetPagination(GrantedAssetPaginationBase): @@ -42,11 +22,13 @@ class NodeGrantedAssetPagination(GrantedAssetPaginationBase): return node.assets_amount else: logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} not has `pagination_node` -> {self._request.get_full_path()}') - return super().get_count(queryset) + return None class AllGrantedAssetPagination(GrantedAssetPaginationBase): def get_count_from_nodes(self, queryset): + if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: + return None assets_amount = sum(UserAssetGrantedTreeNodeRelation.objects.filter( user=self._user, node_parent_key='' ).values_list('node_assets_amount', flat=True)) diff --git a/apps/perms/serializers/asset/permission.py b/apps/perms/serializers/asset/permission.py index 475b83ee1..2bc723706 100644 --- a/apps/perms/serializers/asset/permission.py +++ b/apps/perms/serializers/asset/permission.py @@ -3,9 +3,12 @@ from rest_framework import serializers from django.utils.translation import ugettext_lazy as _ +from django.db.models import Prefetch from orgs.mixins.serializers import BulkOrgResourceModelSerializer from perms.models import AssetPermission, Action +from assets.models import Asset, Node, SystemUser +from users.models import User, UserGroup __all__ = [ 'AssetPermissionSerializer', @@ -68,5 +71,11 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer): @classmethod def setup_eager_loading(cls, queryset): """ Perform necessary eager loading of data. """ - queryset = queryset.prefetch_related('users', 'user_groups', 'assets', 'nodes', 'system_users') + queryset = queryset.prefetch_related( + Prefetch('system_users', queryset=SystemUser.objects.only('id')), + Prefetch('user_groups', queryset=UserGroup.objects.only('id')), + Prefetch('users', queryset=User.objects.only('id')), + Prefetch('assets', queryset=Asset.objects.only('id')), + Prefetch('nodes', queryset=Node.objects.only('id')) + ) return queryset diff --git a/apps/perms/utils/asset/user_permission.py b/apps/perms/utils/asset/user_permission.py index 6bd2e8b2c..7f8d53c7e 100644 --- a/apps/perms/utils/asset/user_permission.py +++ b/apps/perms/utils/asset/user_permission.py @@ -115,8 +115,8 @@ class UnionQuerySet(QuerySet): def __getitem__(self, item): return self.__execute()[item] - def __next__(self): - return next(self.__execute()) + def __iter__(self): + return iter(self.__execute()) @classmethod def test_it(cls): @@ -299,12 +299,12 @@ class UserGrantedTreeRefreshController: cls.remove_builed_orgs_from_users(orgs_id, users_id) @classmethod + @ensure_in_real_or_default_org def add_need_refresh_on_nodes_assets_relate_change(cls, node_ids, asset_ids): """ 1,计算与这些资产有关的授权 2,计算与这些节点以及祖先节点有关的授权 """ - ensure_in_real_or_default_org() node_ids = set(node_ids) ancestor_node_keys = set() @@ -340,8 +340,8 @@ class UserGrantedTreeRefreshController: cls.add_need_refresh_by_asset_perm_ids(perm_ids) @classmethod + @ensure_in_real_or_default_org def add_need_refresh_by_asset_perm_ids(cls, asset_perm_ids): - ensure_in_real_or_default_org() group_ids = AssetPermission.user_groups.through.objects.filter( assetpermission_id__in=asset_perm_ids @@ -429,8 +429,8 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase): return asset_ids @timeit + @ensure_in_real_or_default_org def rebuild_user_granted_tree(self): - ensure_in_real_or_default_org() logger.info(f'Rebuild user:{self.user} tree in org:{current_org}') user = self.user @@ -618,13 +618,13 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase): class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): - def get_favorite_assets(self, only=('id', )) -> QuerySet: + def get_favorite_assets(self) -> QuerySet: favorite_asset_ids = FavoriteAsset.objects.filter( user=self.user ).values_list('asset_id', flat=True) favorite_asset_ids = list(favorite_asset_ids) assets = self.get_all_granted_assets() - assets = assets.filter(id__in=favorite_asset_ids).only(*only) + assets = assets.filter(id__in=favorite_asset_ids) return assets def get_ungroup_assets(self) -> AssetQuerySet: @@ -670,7 +670,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): granted_status = node.get_granted_status(self.user) if granted_status == NodeFrom.granted: - assets = Asset.objects.order_by().filter(nodes_id=node.id) + assets = Asset.objects.order_by().filter(nodes__id=node.id) return assets elif granted_status == NodeFrom.asset: return self._get_indirect_granted_node_assets(node.id) @@ -678,7 +678,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): return Asset.objects.none() def _get_indirect_granted_node_assets(self, id) -> AssetQuerySet: - assets = Asset.objects.order_by().filter(nodes_id=id) & self.get_direct_granted_assets() + assets = Asset.objects.order_by().filter(nodes__id=id).distinct() & self.get_direct_granted_assets() return assets def _get_indirect_granted_node_all_assets(self, node) -> QuerySet: