From 7cf6e54f01e9691769ca4be2fe0b74f9d3a91b27 Mon Sep 17 00:00:00 2001 From: "Jiangjie.Bai" <32935519+BaiJiangJie@users.noreply.github.com> Date: Fri, 5 Feb 2021 13:29:29 +0800 Subject: [PATCH] =?UTF-8?q?refactor=20tree=20(=E9=87=8D=E6=9E=84&=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E8=B5=84=E4=BA=A7=E6=A0=91/=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E6=8E=88=E6=9D=83=E6=A0=91=E5=8A=A0=E8=BD=BD=E9=80=9F=E5=BA=A6?= =?UTF-8?q?)=20(#5548)=20(#5549)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bai reactor tree ( 重构获取完整资产树中节点下资产总数的逻辑) (#5548) * tree: v0.1 * tree: v0.2 * tree: v0.3 * tree: v0.4 * tree: 添加并发锁未请求到时的debug日志 * 以空间换时间的方式优化资产树 * Reactor tree togther v2 (#5576) * Bai reactor tree ( 重构获取完整资产树中节点下资产总数的逻辑) (#5548) * tree: v0.1 * tree: v0.2 * tree: v0.3 * tree: v0.4 * tree: 添加并发锁未请求到时的debug日志 * 以空间换时间的方式优化资产树 * 修改授权适配新方案 * 添加树处理工具 * 完成新的用户授权树计算以及修改一些信号 * 重构了获取资产的一些 api * 重构了一些节点的api * 整理了一些代码 * 完成了api 的重构 * 重构检查节点数量功能 * 完成重构授权树工具类 * api 添加强制刷新参数 * 整理一些信号 * 处理一些信号的问题 * 完成了信号的处理 * 重构了资产树相关的锁机制 * RebuildUserTreeTask 还得添加回来 * 优化下不能在root组织的检查函数 * 优化资产树变化时锁的使用 * 修改一些算法的小工具 * 资产树锁不再校验是否在具体组织里 * 整理了一些信号的位置 * 修复资产与节点关系维护的bug * 去掉一些调试代码 * 修复资产授权过期检查刷新授权树的 bug * 添加了可重入锁 * 添加一些计时,优化一些sql * 增加 union 查询的支持 * 尝试用 sql 解决节点资产数量问题 * 开始优化计算授权树节点资产数量不用冗余表 * 新代码能跑起来了,修复一下bug * 去掉 UserGrantedMappingNode 换成 UserAssetGrantedTreeNodeRelation * 修了些bug,做了些优化 * 优化QuerySetStage 执行逻辑 * 与小白的内存结合了 * 删掉老的表,迁移新的 assets_amount 字段 * 优化用户授权页面资产列表 count 慢 * 修复批量命令数量不对 * 修改获取非直接授权节点的 children 的逻辑 * 获取整棵树的节点 * 回退锁 * 整理迁移脚本 * 改变更新树策略 * perf: 修改一波缩进 * fix: 修改handler名称 * 修复授权树获取资产sql 泛滥 * 修复授权规则有效bug * 修复一些bug * 修复一些bug * 又修了一些小bug * 去掉了老的 get_nodes_all_assets * 修改一些写法 * Reactor tree togther b2 (#5570) * fix: 修改handler名称 * perf: 优化生成树 * perf: 去掉注释 * 优化了一些 * 重新生成迁移脚本 * 去掉周期检查节点资产数量的任务 * Pr@reactor tree togther guang@perf mapping (#5573) * fix: 修改handler名称 * perf: mapping 拆分出来 * 修改名称 * perf: 修改锁名 * perf: 去掉检查节点任务 * perf: 修改一下名称 * perf: 优化一波 Co-authored-by: Jiangjie.Bai <32935519+BaiJiangJie@users.noreply.github.com> Co-authored-by: Bai Co-authored-by: xinwen Co-authored-by: xinwen Co-authored-by: 老广 --- apps/assets/api/asset.py | 4 +- apps/assets/api/mixin.py | 4 + apps/assets/api/node.py | 21 +- apps/assets/locks.py | 21 + .../0066_remove_node_assets_amount.py | 17 + apps/assets/models/asset.py | 10 +- apps/assets/models/base.py | 1 + apps/assets/models/favorite_asset.py | 12 - apps/assets/models/node.py | 185 ++- apps/assets/models/user.py | 2 +- apps/assets/serializers/asset.py | 9 +- apps/assets/signals_handler/__init__.py | 2 + .../common.py} | 139 +- .../signals_handler/maintain_nodes_tree.py | 88 ++ apps/assets/tasks/common.py | 3 +- apps/assets/tasks/gather_asset_users.py | 3 +- apps/assets/tasks/nodes_amount.py | 27 - apps/assets/tests/tree.py | 33 + apps/assets/urls/api_urls.py | 7 +- apps/assets/utils.py | 113 +- apps/common/const/distributed_lock_key.py | 2 - apps/common/db/models.py | 4 + apps/common/utils/common.py | 19 + apps/common/utils/lock.py | 53 +- apps/orgs/lock.py | 131 -- apps/orgs/utils.py | 5 + .../api/application/user_permission/common.py | 6 +- .../user_permission_applications.py | 6 +- apps/perms/api/asset/user_permission/mixin.py | 36 +- .../user_permission/user_permission_assets.py | 156 --- .../user_permission_assets/__init__.py | 1 + .../user_permission_assets/mixin.py | 127 ++ .../user_permission_assets/views.py | 99 ++ .../user_permission/user_permission_nodes.py | 55 +- .../user_permission_nodes_with_assets.py | 184 +-- apps/perms/api/system_user_permission.py | 6 +- apps/perms/async_tasks/__init__.py | 0 apps/perms/async_tasks/mapping_node_task.py | 47 - apps/perms/locks.py | 11 + .../migrations/0014_build_users_perm_tree.py | 14 - .../migrations/0018_auto_20210204_1749.py | 65 + apps/perms/models/asset_permission.py | 123 +- apps/perms/pagination.py | 38 +- apps/perms/signals_handler/__init__.py | 2 + .../common.py} | 88 +- apps/perms/signals_handler/refresh_perms.py | 115 ++ apps/perms/tasks.py | 90 +- apps/perms/utils/asset/user_permission.py | 1122 ++++++++++------- utils/generate_fake_data/resources/assets.py | 3 +- 49 files changed, 1829 insertions(+), 1480 deletions(-) create mode 100644 apps/assets/locks.py create mode 100644 apps/assets/migrations/0066_remove_node_assets_amount.py create mode 100644 apps/assets/signals_handler/__init__.py rename apps/assets/{signals_handler.py => signals_handler/common.py} (60%) create mode 100644 apps/assets/signals_handler/maintain_nodes_tree.py create mode 100644 apps/assets/tests/tree.py delete mode 100644 apps/common/const/distributed_lock_key.py delete mode 100644 apps/orgs/lock.py delete mode 100644 apps/perms/api/asset/user_permission/user_permission_assets.py create mode 100644 apps/perms/api/asset/user_permission/user_permission_assets/__init__.py create mode 100644 apps/perms/api/asset/user_permission/user_permission_assets/mixin.py create mode 100644 apps/perms/api/asset/user_permission/user_permission_assets/views.py delete mode 100644 apps/perms/async_tasks/__init__.py delete mode 100644 apps/perms/async_tasks/mapping_node_task.py create mode 100644 apps/perms/locks.py create mode 100644 apps/perms/migrations/0018_auto_20210204_1749.py create mode 100644 apps/perms/signals_handler/__init__.py rename apps/perms/{signals_handler.py => signals_handler/common.py} (71%) create mode 100644 apps/perms/signals_handler/refresh_perms.py diff --git a/apps/assets/api/asset.py b/apps/assets/api/asset.py index 4de9fb899..9d8a6bf89 100644 --- a/apps/assets/api/asset.py +++ b/apps/assets/api/asset.py @@ -3,10 +3,10 @@ from assets.api import FilterAssetByNodeMixin from rest_framework.viewsets import ModelViewSet from rest_framework.generics import RetrieveAPIView -from rest_framework.response import Response -from rest_framework import status from django.shortcuts import get_object_or_404 +from django.utils.decorators import method_decorator +from assets.locks import NodeTreeUpdateLock from common.utils import get_logger, get_object_or_none from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser from orgs.mixins.api import OrgBulkModelViewSet diff --git a/apps/assets/api/mixin.py b/apps/assets/api/mixin.py index 386a1f507..763f025f4 100644 --- a/apps/assets/api/mixin.py +++ b/apps/assets/api/mixin.py @@ -1,5 +1,6 @@ from typing import List +from common.utils.common import timeit from assets.models import Node, Asset from assets.pagination import AssetLimitOffsetPagination from common.utils import lazyproperty @@ -7,6 +8,8 @@ from assets.utils import get_node, is_query_node_all_assets class SerializeToTreeNodeMixin: + + @timeit def serialize_nodes(self, nodes: List[Node], with_asset_amount=False): if with_asset_amount: def _name(node: Node): @@ -43,6 +46,7 @@ class SerializeToTreeNodeMixin: return platform return default + @timeit def serialize_assets(self, assets, node_key=None): if node_key is None: get_pid = lambda asset: getattr(asset, 'parent_key', '') diff --git a/apps/assets/api/node.py b/apps/assets/api/node.py index a64326042..c793ad384 100644 --- a/apps/assets/api/node.py +++ b/apps/assets/api/node.py @@ -17,12 +17,9 @@ from common.const.signals import PRE_REMOVE, POST_REMOVE from assets.models import Asset from common.utils import get_logger, get_object_or_none from common.tree import TreeNodeSerializer -from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY from orgs.mixins.api import OrgModelViewSet from orgs.mixins import generics -from orgs.lock import org_level_transaction_lock from orgs.utils import current_org -from assets.tasks import check_node_assets_amount_task from ..hands import IsOrgAdmin from ..models import Node from ..tasks import ( @@ -31,6 +28,7 @@ from ..tasks import ( ) from .. import serializers from .mixin import SerializeToTreeNodeMixin +from assets.locks import NodeTreeUpdateLock logger = get_logger(__file__) @@ -50,11 +48,6 @@ class NodeViewSet(OrgModelViewSet): permission_classes = (IsOrgAdmin,) serializer_class = serializers.NodeSerializer - @action(methods=[POST], detail=False, url_name='launch-check-assets-amount-task') - def launch_check_assets_amount_task(self, request): - task = check_node_assets_amount_task.delay(current_org.id) - return Response(data={'task': task.id}) - # 仅支持根节点指直接创建,子节点下的节点需要通过children接口创建 def perform_create(self, serializer): child_key = Node.org_root().get_next_child_key() @@ -184,9 +177,9 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi): if not include_assets: return [] assets = self.instance.get_assets().only( - "id", "hostname", "ip", "os", - "org_id", "protocols", "is_active" - ) + "id", "hostname", "ip", "os", "platform_id", + "org_id", "protocols", "is_active", + ).prefetch_related('platform') return self.serialize_assets(assets, self.instance.key) @@ -219,8 +212,6 @@ class NodeAddChildrenApi(generics.UpdateAPIView): return Response("OK") -@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch') -@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put') class NodeAddAssetsApi(generics.UpdateAPIView): model = Node serializer_class = serializers.NodeAssetsSerializer @@ -233,8 +224,6 @@ class NodeAddAssetsApi(generics.UpdateAPIView): instance.assets.add(*tuple(assets)) -@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch') -@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put') class NodeRemoveAssetsApi(generics.UpdateAPIView): model = Node serializer_class = serializers.NodeAssetsSerializer @@ -251,8 +240,6 @@ class NodeRemoveAssetsApi(generics.UpdateAPIView): Node.org_root().assets.add(*orphan_assets) -@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch') -@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put') class MoveAssetsToNodeApi(generics.UpdateAPIView): model = Node serializer_class = serializers.NodeAssetsSerializer diff --git a/apps/assets/locks.py b/apps/assets/locks.py new file mode 100644 index 000000000..b80db3ff8 --- /dev/null +++ b/apps/assets/locks.py @@ -0,0 +1,21 @@ +from orgs.utils import current_org +from common.utils.lock import DistributedLock + + +class NodeTreeUpdateLock(DistributedLock): + name_template = 'assets.node.tree.update.' + + def get_name(self): + if current_org: + org_id = current_org.id + else: + org_id = 'current_org_is_null' + name = self.name_template.format( + org_id=org_id + ) + return name + + def __init__(self, blocking=True): + name = self.get_name() + super().__init__(name=name, blocking=blocking, + release_lock_on_transaction_commit=True) diff --git a/apps/assets/migrations/0066_remove_node_assets_amount.py b/apps/assets/migrations/0066_remove_node_assets_amount.py new file mode 100644 index 000000000..5d7044179 --- /dev/null +++ b/apps/assets/migrations/0066_remove_node_assets_amount.py @@ -0,0 +1,17 @@ +# Generated by Django 3.1 on 2021-02-04 09:49 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('assets', '0065_auto_20210121_1549'), + ] + + operations = [ + migrations.RemoveField( + model_name='node', + name='assets_amount', + ), + ] diff --git a/apps/assets/models/asset.py b/apps/assets/models/asset.py index d4c787bbc..5d133f40d 100644 --- a/apps/assets/models/asset.py +++ b/apps/assets/models/asset.py @@ -17,7 +17,7 @@ from orgs.mixins.models import OrgModelMixin, OrgManager from .base import ConnectivityMixin from .utils import Connectivity -__all__ = ['Asset', 'ProtocolsMixin', 'Platform'] +__all__ = ['Asset', 'ProtocolsMixin', 'Platform', 'AssetQuerySet'] logger = logging.getLogger(__name__) @@ -41,13 +41,6 @@ def default_node(): class AssetManager(OrgManager): - def get_queryset(self): - return super().get_queryset().annotate( - platform_base=models.F('platform__base') - ) - - -class AssetOrgManager(OrgManager): pass @@ -230,7 +223,6 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin): comment = models.TextField(default='', blank=True, verbose_name=_('Comment')) objects = AssetManager.from_queryset(AssetQuerySet)() - org_objects = AssetOrgManager.from_queryset(AssetQuerySet)() _connectivity = None def __str__(self): diff --git a/apps/assets/models/base.py b/apps/assets/models/base.py index b7239da75..094f029bc 100644 --- a/apps/assets/models/base.py +++ b/apps/assets/models/base.py @@ -11,6 +11,7 @@ from django.db import models from django.utils.translation import ugettext_lazy as _ from django.conf import settings +from common.utils.common import timeit from common.utils import ( ssh_key_string_to_obj, ssh_key_gen, get_logger, lazyproperty ) diff --git a/apps/assets/models/favorite_asset.py b/apps/assets/models/favorite_asset.py index b176c8105..3abc69c8c 100644 --- a/apps/assets/models/favorite_asset.py +++ b/apps/assets/models/favorite_asset.py @@ -18,15 +18,3 @@ class FavoriteAsset(CommonModelMixin): @classmethod def get_user_favorite_assets_id(cls, user): return cls.objects.filter(user=user).values_list('asset', flat=True) - - @classmethod - def get_user_favorite_assets(cls, user, asset_perms_id=None): - from assets.models import Asset - from perms.utils.asset.user_permission import get_user_granted_all_assets - asset_ids = get_user_granted_all_assets( - user, - via_mapping_node=False, - asset_perms_id=asset_perms_id - ).values_list('id', flat=True) - query_name = cls.asset.field.related_query_name() - return Asset.org_objects.filter(**{f'{query_name}__user_id': user.id}, id__in=asset_ids).distinct() diff --git a/apps/assets/models/node.py b/apps/assets/models/node.py index 58e267f70..e5b53eb45 100644 --- a/apps/assets/models/node.py +++ b/apps/assets/models/node.py @@ -1,23 +1,32 @@ # -*- coding: utf-8 -*- # -import uuid import re +import time +import uuid +import threading +import os +import time +import uuid +from collections import defaultdict from django.db import models, transaction -from django.db.models import Q +from django.db.models import Q, Manager from django.db.utils import IntegrityError from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext from django.db.transaction import atomic +from django.core.cache import cache +from common.utils.lock import DistributedLock +from common.utils.common import timeit +from common.db.models import output_as_string from common.utils import get_logger -from common.utils.common import lazyproperty from orgs.mixins.models import OrgModelMixin, OrgManager from orgs.utils import get_current_org, tmp_to_org from orgs.models import Organization -__all__ = ['Node', 'FamilyMixin', 'compute_parent_key'] +__all__ = ['Node', 'FamilyMixin', 'compute_parent_key', 'NodeQuerySet'] logger = get_logger(__name__) @@ -247,9 +256,125 @@ class FamilyMixin: return [*tuple(ancestors), self, *tuple(children)] -class NodeAssetsMixin: +class NodeAllAssetsMappingMixin: + # Use a new plan + + # { org_id: { node_key: [ asset1_id, asset2_id ] } } + orgid_nodekey_assetsid_mapping = defaultdict(dict) + + @classmethod + def get_node_all_assets_id_mapping(cls, org_id): + _mapping = cls.get_node_all_assets_id_mapping_from_memory(org_id) + if _mapping: + return _mapping + + _mapping = cls.get_node_all_assets_id_mapping_from_cache_or_generate_to_cache(org_id) + cls.set_node_all_assets_id_mapping_to_memory(org_id, mapping=_mapping) + return _mapping + + # from memory + @classmethod + def get_node_all_assets_id_mapping_from_memory(cls, org_id): + mapping = cls.orgid_nodekey_assetsid_mapping.get(org_id, {}) + return mapping + + @classmethod + def set_node_all_assets_id_mapping_to_memory(cls, org_id, mapping): + cls.orgid_nodekey_assetsid_mapping[org_id] = mapping + + @classmethod + def expire_node_all_assets_id_mapping_from_memory(cls, org_id): + org_id = str(org_id) + cls.orgid_nodekey_assetsid_mapping.pop(org_id, None) + + # get order: from memory -> (from cache -> to generate) + @classmethod + def get_node_all_assets_id_mapping_from_cache_or_generate_to_cache(cls, org_id): + mapping = cls.get_node_all_assets_id_mapping_from_cache(org_id) + if mapping: + return mapping + + lock_key = f'KEY_LOCK_GENERATE_ORG_{org_id}_NODE_ALL_ASSETS_ID_MAPPING' + logger.info(f'Thread[{threading.get_ident()}] acquiring lock[{lock_key}] ...') + with DistributedLock(lock_key): + logger.info(f'Thread[{threading.get_ident()}] acquire lock[{lock_key}] ok') + # 这里使用无限期锁,原因是如果这里卡住了,就卡在数据库了,说明 + # 数据库繁忙,所以不应该再有线程执行这个操作,使数据库忙上加忙 + + # 这里最好先判断内存中有没有,防止同一进程的多个线程重复从 cache 中获取数据, + # 但逻辑过于繁琐,直接判断 cache 吧 + _mapping = cls.get_node_all_assets_id_mapping_from_cache(org_id) + if _mapping: + return _mapping + + _mapping = cls.generate_node_all_assets_id_mapping(org_id) + cls.set_node_all_assets_id_mapping_to_cache(org_id=org_id, mapping=_mapping) + return _mapping + + @classmethod + def get_node_all_assets_id_mapping_from_cache(cls, org_id): + cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id) + mapping = cache.get(cache_key) + return mapping + + @classmethod + def set_node_all_assets_id_mapping_to_cache(cls, org_id, mapping): + cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id) + cache.set(cache_key, mapping, timeout=None) + + @classmethod + def expire_node_all_assets_id_mapping_from_cache(cls, org_id): + cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id) + cache.delete(cache_key) + + @staticmethod + def _get_cache_key_for_node_all_assets_id_mapping(org_id): + return 'ASSETS_ORG_NODE_ALL_ASSETS_ID_MAPPING_{}'.format(org_id) + + @classmethod + def generate_node_all_assets_id_mapping(cls, org_id): + from .asset import Asset + + t1 = time.time() + with tmp_to_org(org_id): + nodes_id_key = Node.objects.filter(org_id=org_id) \ + .annotate(char_id=output_as_string('id')) \ + .values_list('char_id', 'key') + + # * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢) + nodes_assets_id = Asset.nodes.through.objects.all() \ + .annotate(char_node_id=output_as_string('node_id')) \ + .annotate(char_asset_id=output_as_string('asset_id')) \ + .values_list('char_node_id', 'char_asset_id') + + node_id_ancestor_keys_mapping = { + node_id: cls.get_node_ancestor_keys(node_key, with_self=True) + for node_id, node_key in nodes_id_key + } + + nodeid_assetsid_mapping = defaultdict(set) + for node_id, asset_id in nodes_assets_id: + nodeid_assetsid_mapping[node_id].add(asset_id) + + t2 = time.time() + + mapping = defaultdict(set) + for node_id, node_key in nodes_id_key: + assets_id = nodeid_assetsid_mapping[node_id] + node_ancestor_keys = node_id_ancestor_keys_mapping[node_id] + for ancestor_key in node_ancestor_keys: + mapping[ancestor_key].update(assets_id) + + t3 = time.time() + logger.debug('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2-t1, t3-t2)) + return mapping + + +class NodeAssetsMixin(NodeAllAssetsMappingMixin): + org_id: str key = '' id = None + objects: Manager def get_all_assets(self): from .asset import Asset @@ -263,8 +388,7 @@ class NodeAssetsMixin: # 可是 startswith 会导致表关联时 Asset 索引失效 from .asset import Asset node_ids = cls.objects.filter( - Q(key__startswith=f'{key}:') | - Q(key=key) + Q(key__startswith=f'{key}:') | Q(key=key) ).values_list('id', flat=True).distinct() assets = Asset.objects.filter( nodes__id__in=list(node_ids) @@ -283,29 +407,39 @@ class NodeAssetsMixin: return self.get_all_assets().valid() @classmethod - def get_nodes_all_assets_ids(cls, nodes_keys): - assets_ids = cls.get_nodes_all_assets(nodes_keys).values_list('id', flat=True) + def get_nodes_all_assets_ids_by_keys(cls, nodes_keys): + nodes = Node.objects.filter(key__in=nodes_keys) + assets_ids = cls.get_nodes_all_assets(*nodes).values_list('id', flat=True) return assets_ids @classmethod - def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None): + def get_nodes_all_assets(cls, *nodes): from .asset import Asset - nodes_keys = cls.clean_children_keys(nodes_keys) - q = Q() - node_ids = () - for key in nodes_keys: - q |= Q(key__startswith=f'{key}:') - q |= Q(key=key) - if q: - node_ids = Node.objects.filter(q).distinct().values_list('id', flat=True) + node_ids = set() + descendant_node_query = Q() + for n in nodes: + node_ids.add(n.id) + descendant_node_query |= Q(key__istartswith=f'{n.key}:') + if descendant_node_query: + _ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True) + node_ids.update(_ids) + return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct() - q = Q(nodes__id__in=list(node_ids)) - if extra_assets_ids: - q |= Q(id__in=extra_assets_ids) - if q: - return Asset.org_objects.filter(q).distinct() - else: - return Asset.objects.none() + @property + def assets_amount(self): + assets_id = self.get_all_assets_id() + return len(assets_id) + + def get_all_assets_id(self): + assets_id = self.get_all_assets_id_by_node_key(org_id=self.org_id, node_key=self.key) + return set(assets_id) + + @classmethod + def get_all_assets_id_by_node_key(cls, org_id, node_key): + org_id = str(org_id) + nodekey_assetsid_mapping = cls.get_node_all_assets_id_mapping(org_id) + assets_id = nodekey_assetsid_mapping.get(node_key, []) + return set(assets_id) class SomeNodesMixin: @@ -416,7 +550,6 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin): date_create = models.DateTimeField(auto_now_add=True) parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"), db_index=True, default='') - assets_amount = models.IntegerField(default=0) objects = OrgManager.from_queryset(NodeQuerySet)() is_node = True diff --git a/apps/assets/models/user.py b/apps/assets/models/user.py index 5b800311b..885543796 100644 --- a/apps/assets/models/user.py +++ b/apps/assets/models/user.py @@ -199,7 +199,7 @@ class SystemUser(BaseUser): from assets.models import Node nodes_keys = self.nodes.all().values_list('key', flat=True) assets_ids = set(self.assets.all().values_list('id', flat=True)) - nodes_assets_ids = Node.get_nodes_all_assets_ids(nodes_keys) + nodes_assets_ids = Node.get_nodes_all_assets_ids_by_keys(nodes_keys) assets_ids.update(nodes_assets_ids) assets = Asset.objects.filter(id__in=assets_ids) return assets diff --git a/apps/assets/serializers/asset.py b/apps/assets/serializers/asset.py index a8ce0f3ee..7efe30186 100644 --- a/apps/assets/serializers/asset.py +++ b/apps/assets/serializers/asset.py @@ -111,7 +111,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer): @classmethod def setup_eager_loading(cls, queryset): """ Perform necessary eager loading of data. """ - queryset = queryset.select_related('admin_user', 'domain', 'platform') + queryset = queryset.prefetch_related('admin_user', 'domain', 'platform') queryset = queryset.prefetch_related('nodes', 'labels') return queryset @@ -166,13 +166,6 @@ class AssetDisplaySerializer(AssetSerializer): 'connectivity', ] - @classmethod - def setup_eager_loading(cls, queryset): - queryset = super().setup_eager_loading(queryset) - queryset = queryset\ - .annotate(admin_user_username=F('admin_user__username')) - return queryset - class PlatformSerializer(serializers.ModelSerializer): meta = serializers.DictField(required=False, allow_null=True) diff --git a/apps/assets/signals_handler/__init__.py b/apps/assets/signals_handler/__init__.py new file mode 100644 index 000000000..0c3980565 --- /dev/null +++ b/apps/assets/signals_handler/__init__.py @@ -0,0 +1,2 @@ +from .common import * +from .maintain_nodes_tree import * diff --git a/apps/assets/signals_handler.py b/apps/assets/signals_handler/common.py similarity index 60% rename from apps/assets/signals_handler.py rename to apps/assets/signals_handler/common.py index 061d7d84c..6625e493e 100644 --- a/apps/assets/signals_handler.py +++ b/apps/assets/signals_handler/common.py @@ -1,21 +1,17 @@ # -*- coding: utf-8 -*- # -from operator import add, sub - -from assets.utils import is_asset_exists_in_node from django.db.models.signals import ( post_save, m2m_changed, pre_delete, post_delete, pre_save ) -from django.db.models import Q, F from django.dispatch import receiver from common.exceptions import M2MReverseNotAllowed -from common.const.signals import PRE_ADD, POST_ADD, POST_REMOVE, PRE_CLEAR, PRE_REMOVE +from common.const.signals import POST_ADD, POST_REMOVE, PRE_REMOVE from common.utils import get_logger from common.decorator import on_transaction_commit -from .models import Asset, SystemUser, Node, compute_parent_key +from assets.models import Asset, SystemUser, Node from users.models import User -from .tasks import ( +from assets.tasks import ( update_assets_hardware_info_util, test_asset_connectivity_util, push_system_user_to_assets_manual, @@ -23,7 +19,6 @@ from .tasks import ( add_nodes_assets_to_system_users ) - logger = get_logger(__file__) @@ -202,134 +197,6 @@ def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs): m2m_model.objects.bulk_create(to_create) -def _update_node_assets_amount(node: Node, asset_pk_set: set, operator=add): - """ - 一个节点与多个资产关系变化时,更新计数 - - :param node: 节点实例 - :param asset_pk_set: 资产的`id`集合, 内部不会修改该值 - :param operator: 操作 - * -> Node - # -> Asset - - * [3] - / \ - * * [2] - / \ - * * [1] - / / \ - * [a] # # [b] - - """ - # 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key` - ancestor_keys = node.get_ancestor_keys(with_self=True) - ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key') - to_update = [] - for ancestor in ancestors: - # 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3] - # 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的 - # 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量 - - asset_pk_set -= set(Asset.objects.filter( - id__in=asset_pk_set - ).filter( - Q(nodes__key__istartswith=f'{ancestor.key}:') | - Q(nodes__key=ancestor.key) - ).distinct().values_list('id', flat=True)) - if not asset_pk_set: - # 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量 - # 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用 - # 处理了 - break - ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set)) - to_update.append(ancestor) - Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key')) - - -def _remove_ancestor_keys(ancestor_key, tree_set): - # 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环 - # 判断是否在集合里,来区分是否已被处理过 - while ancestor_key and ancestor_key in tree_set: - tree_set.remove(ancestor_key) - ancestor_key = compute_parent_key(ancestor_key) - - -def _update_nodes_asset_amount(node_keys, asset_pk, operator): - """ - 一个资产与多个节点关系变化时,更新计数 - - :param node_keys: 节点 id 的集合 - :param asset_pk: 资产 id - :param operator: 操作 - """ - - # 所有相关节点的祖先节点,组成一棵局部树 - ancestor_keys = set() - for key in node_keys: - ancestor_keys.update(Node.get_node_ancestor_keys(key)) - - # 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉 - node_keys -= ancestor_keys - - to_update_keys = [] - for key in node_keys: - # 遍历相关节点,处理它及其祖先节点 - # 查询该节点是否包含待处理资产 - exists = is_asset_exists_in_node(asset_pk, key) - parent_key = compute_parent_key(key) - - if exists: - # 如果资产在该节点,那么他及其祖先节点都不用处理 - _remove_ancestor_keys(parent_key, ancestor_keys) - continue - else: - # 不存在,要更新本节点 - to_update_keys.append(key) - # 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环 - # 判断是否在集合里,来区分是否已被处理过 - while parent_key and parent_key in ancestor_keys: - exists = is_asset_exists_in_node(asset_pk, parent_key) - if exists: - _remove_ancestor_keys(parent_key, ancestor_keys) - break - else: - to_update_keys.append(parent_key) - ancestor_keys.remove(parent_key) - parent_key = compute_parent_key(parent_key) - - Node.objects.filter(key__in=to_update_keys).update( - assets_amount=operator(F('assets_amount'), 1) - ) - - -@receiver(m2m_changed, sender=Asset.nodes.through) -def update_nodes_assets_amount(action, instance, reverse, pk_set, **kwargs): - # 不允许 `pre_clear` ,因为该信号没有 `pk_set` - # [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed) - refused = (PRE_CLEAR,) - if action in refused: - raise ValueError - - mapper = { - PRE_ADD: add, - POST_REMOVE: sub - } - if action not in mapper: - return - - operator = mapper[action] - - if reverse: - node: Node = instance - asset_pk_set = set(pk_set) - _update_node_assets_amount(node, asset_pk_set, operator) - else: - asset_pk = instance.id - # 与资产直接关联的节点 - node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True)) - _update_nodes_asset_amount(node_keys, asset_pk, operator) - - RELATED_NODE_IDS = '_related_node_ids' diff --git a/apps/assets/signals_handler/maintain_nodes_tree.py b/apps/assets/signals_handler/maintain_nodes_tree.py new file mode 100644 index 000000000..d6615c885 --- /dev/null +++ b/apps/assets/signals_handler/maintain_nodes_tree.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# +import os +import threading + +from django.db.models.signals import ( + m2m_changed, post_save, post_delete +) +from django.dispatch import receiver +from django.utils.functional import LazyObject + +from common.signals import django_ready +from common.utils.connection import RedisPubSub +from common.utils import get_logger +from assets.models import Asset, Node + + +logger = get_logger(__file__) + +# clear node assets mapping for memory +# ------------------------------------ + + +def get_node_assets_mapping_for_memory_pub_sub(): + return RedisPubSub('fm.node_all_assets_id_memory_mapping') + + +class NodeAssetsMappingForMemoryPubSub(LazyObject): + def _setup(self): + self._wrapped = get_node_assets_mapping_for_memory_pub_sub() + + +node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub() + + +def expire_node_assets_mapping_for_memory(org_id): + # 所有进程清除(自己的 memory 数据) + org_id = str(org_id) + node_assets_mapping_for_memory_pub_sub.publish(org_id) + # 当前进程清除(cache 数据) + logger.debug( + "Expire node assets id mapping from cache of org={}, pid={}" + "".format(org_id, os.getpid()) + ) + Node.expire_node_all_assets_id_mapping_from_cache(org_id) + + +@receiver(post_save, sender=Node) +def on_node_post_create(sender, instance, created, update_fields, **kwargs): + if created: + need_expire = True + elif update_fields and 'key' in update_fields: + need_expire = True + else: + need_expire = False + + if need_expire: + expire_node_assets_mapping_for_memory(instance.org_id) + + +@receiver(post_delete, sender=Node) +def on_node_post_delete(sender, instance, **kwargs): + expire_node_assets_mapping_for_memory(instance.org_id) + + +@receiver(m2m_changed, sender=Asset.nodes.through) +def on_node_asset_change(sender, instance, **kwargs): + expire_node_assets_mapping_for_memory(instance.org_id) + + +@receiver(django_ready) +def subscribe_node_assets_mapping_expire(sender, **kwargs): + logger.debug("Start subscribe for expire node assets id mapping from memory") + + def keep_subscribe(): + subscribe = node_assets_mapping_for_memory_pub_sub.subscribe() + for message in subscribe.listen(): + if message["type"] != "message": + continue + org_id = message['data'].decode() + Node.expire_node_all_assets_id_mapping_from_memory(org_id) + logger.debug( + "Expire node assets id mapping from memory of org={}, pid={}" + "".format(str(org_id), os.getpid()) + ) + t = threading.Thread(target=keep_subscribe) + t.daemon = True + t.start() diff --git a/apps/assets/tasks/common.py b/apps/assets/tasks/common.py index b743300e1..5a92ec039 100644 --- a/apps/assets/tasks/common.py +++ b/apps/assets/tasks/common.py @@ -12,6 +12,7 @@ __all__ = ['add_nodes_assets_to_system_users'] @tmp_to_root_org() def add_nodes_assets_to_system_users(nodes_keys, system_users): from ..models import Node - assets = Node.get_nodes_all_assets(nodes_keys).values_list('id', flat=True) + nodes = Node.objects.filter(key__in=nodes_keys) + assets = Node.get_nodes_all_assets(*nodes) for system_user in system_users: system_user.assets.add(*tuple(assets)) diff --git a/apps/assets/tasks/gather_asset_users.py b/apps/assets/tasks/gather_asset_users.py index 5d8372451..0187a29aa 100644 --- a/apps/assets/tasks/gather_asset_users.py +++ b/apps/assets/tasks/gather_asset_users.py @@ -141,7 +141,8 @@ def gather_asset_users(assets, task_name=None): @shared_task(queue="ansible") def gather_nodes_asset_users(nodes_key): - assets = Node.get_nodes_all_assets(nodes_key) + nodes = Node.objects.filter(key__in=nodes_key) + assets = Node.get_nodes_all_assets(*nodes) assets_groups_by_100 = [assets[i:i+100] for i in range(0, len(assets), 100)] for _assets in assets_groups_by_100: gather_asset_users(_assets) diff --git a/apps/assets/tasks/nodes_amount.py b/apps/assets/tasks/nodes_amount.py index e1e437797..e69de29bb 100644 --- a/apps/assets/tasks/nodes_amount.py +++ b/apps/assets/tasks/nodes_amount.py @@ -1,27 +0,0 @@ -from celery import shared_task -from django.utils.translation import gettext_lazy as _ - -from orgs.models import Organization -from orgs.utils import tmp_to_org -from ops.celery.decorator import register_as_period_task -from assets.utils import check_node_assets_amount - -from common.utils.lock import AcquireFailed -from common.utils import get_logger - -logger = get_logger(__file__) - - -@shared_task(queue='celery_heavy_tasks') -def check_node_assets_amount_task(org_id=Organization.ROOT_ID): - try: - with tmp_to_org(Organization.get_instance(org_id)): - check_node_assets_amount() - except AcquireFailed: - logger.error(_('The task of self-checking is already running and cannot be started repeatedly')) - - -@register_as_period_task(crontab='0 2 * * *') -@shared_task(queue='celery_heavy_tasks') -def check_node_assets_amount_period_task(): - check_node_assets_amount_task() diff --git a/apps/assets/tests/tree.py b/apps/assets/tests/tree.py new file mode 100644 index 000000000..99bfd2275 --- /dev/null +++ b/apps/assets/tests/tree.py @@ -0,0 +1,33 @@ +from assets.tree import Tree + + +def test(): + from orgs.models import Organization + from assets.models import Node, Asset + import time + Organization.objects.get(id='1863cf22-f666-474e-94aa-935fe175203c').change_to() + + t1 = time.time() + nodes = list(Node.objects.exclude(key__startswith='-').only('id', 'key', 'parent_key')) + node_asset_id_pairs = Asset.nodes.through.objects.all().values_list('node_id', 'asset_id') + t2 = time.time() + node_asset_id_pairs = list(node_asset_id_pairs) + tree = Tree(nodes, node_asset_id_pairs) + tree.build_tree() + tree.nodes = None + tree.node_asset_id_pairs = None + import pickle + d = pickle.dumps(tree) + print('------------', len(d)) + return tree + tree.compute_tree_node_assets_amount() + + print(f'校对算法准确性 ......') + for node in nodes: + tree_node = tree.key_tree_node_mapper[node.key] + if tree_node.assets_amount != node.assets_amount: + print(f'ERROR: {tree_node.assets_amount} {node.assets_amount}') + # print(f'OK {tree_node.asset_amount} {node.assets_amount}') + + print(f'数据库时间: {t2 - t1}') + return tree \ No newline at end of file diff --git a/apps/assets/urls/api_urls.py b/apps/assets/urls/api_urls.py index 707a8e73d..5a5e6d803 100644 --- a/apps/assets/urls/api_urls.py +++ b/apps/assets/urls/api_urls.py @@ -2,7 +2,6 @@ from django.urls import path, re_path from rest_framework_nested import routers from rest_framework_bulk.routes import BulkRouter -from django.db.transaction import non_atomic_requests from common import api as capi @@ -57,9 +56,9 @@ urlpatterns = [ path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'), path('nodes//children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'), path('nodes//assets/', api.NodeAssetsApi.as_view(), name='node-assets'), - path('nodes//assets/add/', non_atomic_requests(api.NodeAddAssetsApi.as_view()), name='node-add-assets'), - path('nodes//assets/replace/', non_atomic_requests(api.MoveAssetsToNodeApi.as_view()), name='node-replace-assets'), - path('nodes//assets/remove/', non_atomic_requests(api.NodeRemoveAssetsApi.as_view()), name='node-remove-assets'), + path('nodes//assets/add/', api.NodeAddAssetsApi.as_view(), name='node-add-assets'), + path('nodes//assets/replace/', api.MoveAssetsToNodeApi.as_view(), name='node-replace-assets'), + path('nodes//assets/remove/', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'), path('nodes//tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'), path('gateways//test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'), diff --git a/apps/assets/utils.py b/apps/assets/utils.py index 2805ac034..343fa704b 100644 --- a/apps/assets/utils.py +++ b/apps/assets/utils.py @@ -1,43 +1,16 @@ # ~*~ coding: utf-8 ~*~ # -import time - -from django.db.models import Q - -from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none -from common.utils.lock import DistributedLock +from collections import defaultdict +from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit from common.http import is_true -from .models import Asset, Node +from common.struct import Stack +from common.db.models import output_as_string +from .models import Node logger = get_logger(__file__) -@DistributedLock(name="assets.node.check_node_assets_amount", blocking=False) -def check_node_assets_amount(): - for node in Node.objects.all(): - logger.info(f'Check node assets amount: {node}') - assets_amount = Asset.objects.filter( - Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node) - ).distinct().count() - - if node.assets_amount != assets_amount: - logger.warn(f'Node wrong assets amount ' - f'{node.assets_amount} right is {assets_amount}') - node.assets_amount = assets_amount - node.save() - # 防止自检程序给数据库的压力太大 - time.sleep(0.1) - - -def is_asset_exists_in_node(asset_pk, node_key): - return Asset.objects.filter( - id=asset_pk - ).filter( - Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key) - ).exists() - - def is_query_node_all_assets(request): request = request query_all_arg = request.query_params.get('all', 'true') @@ -57,3 +30,79 @@ def get_node(request): else: node = get_object_or_none(Node, key=node_id) return node + + +class NodeAssetsInfo: + __slots__ = ('key', 'assets_amount', 'assets') + + def __init__(self, key, assets_amount, assets): + self.key = key + self.assets_amount = assets_amount + self.assets = assets + + def __str__(self): + return self.key + + +class NodeAssetsUtil: + def __init__(self, nodes, nodekey_assetsid_mapper): + """ + :param nodes: 节点 + :param nodekey_assetsid_mapper: 节点直接资产id的映射 {"key1": set(), "key2": set()} + """ + self.nodes = nodes + # node_id --> set(asset_id1, asset_id2) + self.nodekey_assetsid_mapper = nodekey_assetsid_mapper + self.nodekey_assetsinfo_mapper = {} + + @timeit + def generate(self): + # 准备排序好的资产信息数据 + infos = [] + for node in self.nodes: + assets = self.nodekey_assetsid_mapper.get(node.key, set()) + info = NodeAssetsInfo(key=node.key, assets_amount=0, assets=assets) + infos.append(info) + infos = sorted(infos, key=lambda i: [int(i) for i in i.key.split(':')]) + # 这个守卫需要添加一下,避免最后一个无法出栈 + guarder = NodeAssetsInfo(key='', assets_amount=0, assets=set()) + infos.append(guarder) + + stack = Stack() + for info in infos: + # 如果栈顶的不是这个节点的父祖节点,那么可以出栈了,可以计算资产数量了 + while stack.top and not info.key.startswith(f'{stack.top.key}:'): + pop_info = stack.pop() + pop_info.assets_amount = len(pop_info.assets) + self.nodekey_assetsinfo_mapper[pop_info.key] = pop_info + if not stack.top: + continue + stack.top.assets.update(pop_info.assets) + stack.push(info) + + def get_assets_by_key(self, key): + info = self.nodekey_assetsinfo_mapper[key] + return info['assets'] + + def get_assets_amount(self, key): + info = self.nodekey_assetsinfo_mapper[key] + return info.assets_amount + + @classmethod + def test_it(cls): + from assets.models import Node, Asset + + nodes = list(Node.objects.all()) + nodes_assets = Asset.nodes.through.objects.all()\ + .annotate(aid=output_as_string('asset_id'))\ + .values_list('node__key', 'aid') + + mapping = defaultdict(set) + for key, asset_id in nodes_assets: + mapping[key].add(asset_id) + + util = cls(nodes, mapping) + util.generate() + return util + + diff --git a/apps/common/const/distributed_lock_key.py b/apps/common/const/distributed_lock_key.py deleted file mode 100644 index 735781841..000000000 --- a/apps/common/const/distributed_lock_key.py +++ /dev/null @@ -1,2 +0,0 @@ -UPDATE_NODE_TREE_LOCK_KEY = 'org_level_transaction_lock_{org_id}_assets_update_node_tree' -UPDATE_MAPPING_NODE_TASK_LOCK_KEY = 'org_level_transaction_lock_{user_id}_update_mapping_node_task' diff --git a/apps/common/db/models.py b/apps/common/db/models.py index 502df31e9..df5d6a46d 100644 --- a/apps/common/db/models.py +++ b/apps/common/db/models.py @@ -82,3 +82,7 @@ class JMSModel(JMSBaseModel): def concated_display(name1, name2): return Concat(F(name1), Value('('), F(name2), Value(')')) + + +def output_as_string(field_name): + return ExpressionWrapper(F(field_name), output_field=CharField()) diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py index 8bc7377e5..f6808b0ac 100644 --- a/apps/common/utils/common.py +++ b/apps/common/utils/common.py @@ -254,3 +254,22 @@ def get_disk_usage(): mount_points = [p.mountpoint for p in partitions] usages = {p: psutil.disk_usage(p) for p in mount_points} return usages + + +class Time: + def __init__(self): + self._timestamps = [] + self._msgs = [] + + def begin(self): + self._timestamps.append(time.time()) + + def time(self, msg): + self._timestamps.append(time.time()) + self._msgs.append(msg) + + def print(self): + last, *timestamps = self._timestamps + for timestamp, msg in zip(timestamps, self._msgs): + logger.debug(f'TIME_IT: {msg} {timestamp-last}') + last = timestamp diff --git a/apps/common/utils/lock.py b/apps/common/utils/lock.py index 04ee1520f..d7d7acbed 100644 --- a/apps/common/utils/lock.py +++ b/apps/common/utils/lock.py @@ -1,8 +1,9 @@ from functools import wraps import threading -from redis_lock import Lock as RedisLock +from redis_lock import Lock as RedisLock, NotAcquired from redis import Redis +from django.db import transaction from common.utils import get_logger from common.utils.inspect import copy_function_args @@ -16,7 +17,8 @@ class AcquireFailed(RuntimeError): class DistributedLock(RedisLock): - def __init__(self, name, blocking=True, expire=60*2, auto_renewal=True): + def __init__(self, name, blocking=True, expire=None, release_lock_on_transaction_commit=False, + release_raise_exc=False, auto_renewal_seconds=60*2): """ 使用 redis 构造的分布式锁 @@ -25,31 +27,46 @@ class DistributedLock(RedisLock): :param blocking: 该参数只在锁作为装饰器或者 `with` 时有效。 :param expire: - 锁的过期时间,注意不一定是锁到这个时间就释放了,分两种情况 - 当 `auto_renewal=False` 时,锁会释放 - 当 `auto_renewal=True` 时,如果过期之前程序还没释放锁,我们会延长锁的存活时间。 - 这里的作用是防止程序意外终止没有释放锁,导致死锁。 + 锁的过期时间 + :param release_lock_on_transaction_commit: + 是否在当前事务结束后再释放锁 + :param release_raise_exc: + 释放锁时,如果没有持有锁是否抛异常或静默 + :param auto_renewal_seconds: + 当持有一个无限期锁的时候,刷新锁的时间,具体参考 `redis_lock.Lock#auto_renewal` """ self.kwargs_copy = copy_function_args(self.__init__, locals()) redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD) + + if expire is None: + expire = auto_renewal_seconds + auto_renewal = True + else: + auto_renewal = False + super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal) self._blocking = blocking + self._release_lock_on_transaction_commit = release_lock_on_transaction_commit + self._release_raise_exc = release_raise_exc def __enter__(self): thread_id = threading.current_thread().ident - logger.debug(f'DISTRIBUTED_LOCK: attempt to acquire ...') + logger.debug(f'Attempt to acquire global lock: thread {thread_id} lock {self._name}') acquired = self.acquire(blocking=self._blocking) if self._blocking and not acquired: - logger.debug(f'DISTRIBUTED_LOCK: was not acquired , but blocking=True') + logger.debug(f'Not acquired lock, but blocking=True, thread {thread_id} lock {self._name}') raise EnvironmentError("Lock wasn't acquired, but blocking=True") if not acquired: - logger.debug(f'DISTRIBUTED_LOCK: acquire failed') + logger.debug(f'Not acquired the lock, thread {thread_id} lock {self._name}') raise AcquireFailed - logger.debug(f'DISTRIBUTED_LOCK: acquire ok') + logger.debug(f'Acquire lock success, thread {thread_id} lock {self._name}') return self def __exit__(self, exc_type=None, exc_value=None, traceback=None): - self.release() + if self._release_lock_on_transaction_commit: + transaction.on_commit(self.release) + else: + self.release() def __call__(self, func): @wraps(func) @@ -57,5 +74,17 @@ class DistributedLock(RedisLock): # 要创建一个新的锁对象 with self.__class__(**self.kwargs_copy): return func(*args, **kwds) - return inner + + def locked_by_me(self): + if self.locked(): + if self.get_owner_id() == self.id: + return True + return False + + def release(self): + try: + super().release() + except AcquireFailed as e: + if self._release_raise_exc: + raise e diff --git a/apps/orgs/lock.py b/apps/orgs/lock.py deleted file mode 100644 index c129b8bcd..000000000 --- a/apps/orgs/lock.py +++ /dev/null @@ -1,131 +0,0 @@ -from uuid import uuid4 -from functools import wraps - -from django.core.cache import cache -from django.db.transaction import atomic -from rest_framework.request import Request -from rest_framework.exceptions import NotAuthenticated - -from orgs.utils import current_org -from common.exceptions import SomeoneIsDoingThis, Timeout -from common.utils.timezone import dt_formater, now - -# Redis 中锁值得模板,该模板提供了很强的可读性,方便调试与排错 -VALUE_TEMPLATE = '{stage}:{username}:{user_id}:{now}:{rand_str}' - -# 锁的状态 -DOING = 'doing' # 处理中,此状态的锁可以被干掉 -COMMITING = 'commiting' # 提交事务中,此状态很重要,要确保事务在锁消失之前返回了,不要轻易删除该锁 - -client = cache.client.get_client(write=True) - - -""" -将锁的状态从 `doing` 切换到 `commiting` -KEYS[1]: key -ARGV[1]: doingvalue -ARGV[2]: commitingvalue -ARGV[3]: timeout -""" -change_lock_state_to_commiting_lua = ''' -if (redis.call("get", KEYS[1]) == ARGV[1]) -then - return redis.call("set", KEYS[1], ARGV[2], "EX", ARGV[3], "XX") -else - return 0 -end -''' -change_lock_state_to_commiting_lua_obj = client.register_script(change_lock_state_to_commiting_lua) - - -""" -释放锁,两种`value`都要检查`doing`和`commiting` -KEYS[1]: key -ARGV[1]: 两个 `value` 中的其中一个 -ARGV[2]: 两个 `value` 中的其中一个 -""" -release_lua = ''' -if (redis.call("get",KEYS[1]) == ARGV[1] or redis.call("get",KEYS[1]) == ARGV[2]) -then - return redis.call("del",KEYS[1]) -else - return 0 -end -''' -release_lua_obj = client.register_script(release_lua) - - -def acquire(key, value, timeout): - return client.set(key, value, ex=timeout, nx=True) - - -def get(key): - return client.get(key) - - -def change_lock_state_to_commiting(key, doingvalue, commitingvalue, timeout=600): - # 将锁的状态从 `doing` 切换到 `commiting` - return bool(change_lock_state_to_commiting_lua_obj(keys=(key,), args=(doingvalue, commitingvalue, timeout))) - - -def release(key, value1, value2): - # 释放锁,两种`value` `doing`和`commiting` 都要检查 - return release_lua_obj(keys=(key,), args=(value1, value2)) - - -def _generate_value(request: Request, stage=DOING): - # 不支持匿名用户 - user = request.user - if user.is_anonymous: - raise NotAuthenticated - - return VALUE_TEMPLATE.format( - stage=stage, username=user.username, user_id=user.id, - now=dt_formater(now()), rand_str=uuid4() - ) - - -default_wait_msg = SomeoneIsDoingThis.default_detail - - -def org_level_transaction_lock(key, timeout=300, wait_msg=default_wait_msg): - """ - 被装饰的 `View` 必须取消自身的 `ATOMIC_REQUESTS`,因为该装饰器要有事务的完全控制权 - [官网](https://docs.djangoproject.com/en/3.1/topics/db/transactions/#tying-transactions-to-http-requests) - - 1. 获取锁:只有当锁对应的 `key` 不存在时成功获取,`value` 设置为 `doing` - 2. 开启事务:本次请求的事务必须确保在这里开启 - 3. 执行 `View` 体 - 4. `View` 体执行结束未异常,此时事务还未提交 - 5. 检查锁是否过时,过时事务回滚,不过时,重新设置`key`延长`key`有效期,已确保足够时间提交事务,同时把`key`的状态改为`commiting` - 6. 提交事务 - 7. 释放锁,释放的时候会检查`doing`与`commiting`的值,因为删除或者更改锁必须提供与当前锁的`value`相同的值,确保不误删 - [锁参考文章](http://doc.redisfans.com/string/set.html#id2) - """ - - def decorator(fun): - @wraps(fun) - def wrapper(request, *args, **kwargs): - # `key`可能是组织相关的,如果是把组织`id`加上 - _key = key.format(org_id=current_org.id) - doing_value = _generate_value(request) - commiting_value = _generate_value(request, stage=COMMITING) - try: - lock = acquire(_key, doing_value, timeout) - if not lock: - raise SomeoneIsDoingThis(detail=wait_msg) - with atomic(savepoint=False): - ret = fun(request, *args, **kwargs) - # 提交事务前,检查一下锁是否还在 - # 锁在的话,更新锁的状态为 `commiting`,延长锁时间,确保事务提交 - # 锁不在的话回滚 - ok = change_lock_state_to_commiting(_key, doing_value, commiting_value) - if not ok: - # 超时或者被中断了 - raise Timeout - return ret - finally: - # 释放锁,锁的两个值都要尝试,不确定异常是从什么位置抛出的 - release(_key, commiting_value, doing_value) - return wrapper - return decorator diff --git a/apps/orgs/utils.py b/apps/orgs/utils.py index c10a5dacc..d01ae3f77 100644 --- a/apps/orgs/utils.py +++ b/apps/orgs/utils.py @@ -184,3 +184,8 @@ def org_aware_func(org_arg_name): current_org = LocalProxy(get_current_org) + + +def ensure_in_real_or_default_org(): + if not current_org or current_org.is_root(): + raise ValueError('You must in a real or default org!') diff --git a/apps/perms/api/application/user_permission/common.py b/apps/perms/api/application/user_permission/common.py index 428f6bdc9..272f84378 100644 --- a/apps/perms/api/application/user_permission/common.py +++ b/apps/perms/api/application/user_permission/common.py @@ -13,7 +13,7 @@ from applications.models import Application from perms.utils.application.permission import ( get_application_system_users_id ) -from perms.api.asset.user_permission.mixin import ForAdminMixin, ForUserMixin +from perms.api.asset.user_permission.mixin import RoleAdminMixin, RoleUserMixin from common.permissions import IsOrgAdminOrAppUser from perms.hands import User, SystemUser from perms import serializers @@ -43,11 +43,11 @@ class GrantedApplicationSystemUsersMixin(ListAPIView): return system_users -class UserGrantedApplicationSystemUsersApi(ForAdminMixin, GrantedApplicationSystemUsersMixin): +class UserGrantedApplicationSystemUsersApi(RoleAdminMixin, GrantedApplicationSystemUsersMixin): pass -class MyGrantedApplicationSystemUsersApi(ForUserMixin, GrantedApplicationSystemUsersMixin): +class MyGrantedApplicationSystemUsersApi(RoleUserMixin, GrantedApplicationSystemUsersMixin): pass diff --git a/apps/perms/api/application/user_permission/user_permission_applications.py b/apps/perms/api/application/user_permission/user_permission_applications.py index 2b8b71847..6916f6f29 100644 --- a/apps/perms/api/application/user_permission/user_permission_applications.py +++ b/apps/perms/api/application/user_permission/user_permission_applications.py @@ -8,7 +8,7 @@ from applications.api.mixin import ( SerializeApplicationToTreeNodeMixin ) from perms import serializers -from perms.api.asset.user_permission.mixin import ForAdminMixin, ForUserMixin +from perms.api.asset.user_permission.mixin import RoleAdminMixin, RoleUserMixin from perms.utils.application.user_permission import ( get_user_granted_all_applications ) @@ -34,11 +34,11 @@ class AllGrantedApplicationsMixin(CommonApiMixin, ListAPIView): return queryset.only(*self.only_fields) -class UserAllGrantedApplicationsApi(ForAdminMixin, AllGrantedApplicationsMixin): +class UserAllGrantedApplicationsApi(RoleAdminMixin, AllGrantedApplicationsMixin): pass -class MyAllGrantedApplicationsApi(ForUserMixin, AllGrantedApplicationsMixin): +class MyAllGrantedApplicationsApi(RoleUserMixin, AllGrantedApplicationsMixin): pass diff --git a/apps/perms/api/asset/user_permission/mixin.py b/apps/perms/api/asset/user_permission/mixin.py index 9961a0cd9..5c5db7729 100644 --- a/apps/perms/api/asset/user_permission/mixin.py +++ b/apps/perms/api/asset/user_permission/mixin.py @@ -4,37 +4,23 @@ from rest_framework.request import Request from common.permissions import IsOrgAdminOrAppUser, IsValidUser from common.utils import lazyproperty +from common.http import is_true from orgs.utils import tmp_to_root_org from users.models import User -from perms.models import UserGrantedMappingNode +from perms.utils.asset.user_permission import UserGrantedTreeRefreshController -class UserNodeGrantStatusDispatchMixin: +class PermBaseMixin: + user: User - @staticmethod - def get_mapping_node_by_key(key, user): - return UserGrantedMappingNode.objects.get(key=key, user=user) - - def dispatch_get_data(self, key, user): - status = UserGrantedMappingNode.get_node_granted_status(key, user) - if status == UserGrantedMappingNode.GRANTED_DIRECT: - return self.get_data_on_node_direct_granted(key) - elif status == UserGrantedMappingNode.GRANTED_INDIRECT: - return self.get_data_on_node_indirect_granted(key) - else: - return self.get_data_on_node_not_granted(key) - - def get_data_on_node_direct_granted(self, key): - raise NotImplementedError - - def get_data_on_node_indirect_granted(self, key): - raise NotImplementedError - - def get_data_on_node_not_granted(self, key): - raise NotImplementedError + def get(self, request, *args, **kwargs): + force = is_true(request.query_params.get('rebuild_tree')) + controller = UserGrantedTreeRefreshController(self.user) + controller.refresh_if_need(force) + return super().get(request, *args, **kwargs) -class ForAdminMixin: +class RoleAdminMixin(PermBaseMixin): permission_classes = (IsOrgAdminOrAppUser,) kwargs: dict @@ -44,7 +30,7 @@ class ForAdminMixin: return User.objects.get(id=user_id) -class ForUserMixin: +class RoleUserMixin(PermBaseMixin): permission_classes = (IsValidUser,) request: Request diff --git a/apps/perms/api/asset/user_permission/user_permission_assets.py b/apps/perms/api/asset/user_permission/user_permission_assets.py deleted file mode 100644 index 209b59625..000000000 --- a/apps/perms/api/asset/user_permission/user_permission_assets.py +++ /dev/null @@ -1,156 +0,0 @@ -# -*- coding: utf-8 -*- -# -from perms.api.asset.user_permission.mixin import UserNodeGrantStatusDispatchMixin -from rest_framework.generics import ListAPIView -from rest_framework.response import Response -from rest_framework.request import Request -from django.conf import settings - -from assets.api.mixin import SerializeToTreeNodeMixin -from common.utils import get_logger -from perms.pagination import GrantedAssetLimitOffsetPagination -from assets.models import Asset, Node, FavoriteAsset -from perms import serializers -from perms.utils.asset.user_permission import ( - get_node_all_granted_assets, get_user_direct_granted_assets, - get_user_granted_all_assets -) -from .mixin import ForAdminMixin, ForUserMixin - - -logger = get_logger(__name__) - - -class UserDirectGrantedAssetsApi(ListAPIView): - """ - 用户直接授权的资产的列表,也就是授权规则上直接授权的资产,并非是来自节点的 - """ - serializer_class = serializers.AssetGrantedSerializer - only_fields = serializers.AssetGrantedSerializer.Meta.only_fields - filterset_fields = ['hostname', 'ip', 'id', 'comment'] - search_fields = ['hostname', 'ip', 'comment'] - - def get_queryset(self): - if getattr(self, 'swagger_fake_view', False): - return Asset.objects.none() - user = self.user - assets = get_user_direct_granted_assets(user)\ - .prefetch_related('platform')\ - .only(*self.only_fields) - return assets - - -class UserFavoriteGrantedAssetsApi(ListAPIView): - serializer_class = serializers.AssetGrantedSerializer - only_fields = serializers.AssetGrantedSerializer.Meta.only_fields - filterset_fields = ['hostname', 'ip', 'id', 'comment'] - search_fields = ['hostname', 'ip', 'comment'] - - def get_queryset(self): - if getattr(self, 'swagger_fake_view', False): - return Asset.objects.none() - user = self.user - assets = FavoriteAsset.get_user_favorite_assets(user)\ - .prefetch_related('platform')\ - .only(*self.only_fields) - return assets - - -class AssetsAsTreeMixin(SerializeToTreeNodeMixin): - """ - 将 资产 序列化成树的结构返回 - """ - def list(self, request: Request, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) - if request.query_params.get('search'): - # 如果用户搜索的条件不精准,会导致返回大量的无意义数据。 - # 这里限制一下返回数据的最大条数 - queryset = queryset[:999] - data = self.serialize_assets(queryset, None) - return Response(data=data) - - -class UserDirectGrantedAssetsForAdminApi(ForAdminMixin, UserDirectGrantedAssetsApi): - pass - - -class MyDirectGrantedAssetsApi(ForUserMixin, UserDirectGrantedAssetsApi): - pass - - -class UserFavoriteGrantedAssetsForAdminApi(ForAdminMixin, UserFavoriteGrantedAssetsApi): - pass - - -class MyFavoriteGrantedAssetsApi(ForUserMixin, UserFavoriteGrantedAssetsApi): - pass - - -class UserDirectGrantedAssetsAsTreeForAdminApi(ForAdminMixin, AssetsAsTreeMixin, UserDirectGrantedAssetsApi): - pass - - -class MyUngroupAssetsAsTreeApi(ForUserMixin, AssetsAsTreeMixin, UserDirectGrantedAssetsApi): - def get_queryset(self): - queryset = super().get_queryset() - if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: - queryset = queryset.none() - return queryset - - -class UserAllGrantedAssetsApi(ForAdminMixin, ListAPIView): - only_fields = serializers.AssetGrantedSerializer.Meta.only_fields - serializer_class = serializers.AssetGrantedSerializer - filterset_fields = ['hostname', 'ip', 'id', 'comment'] - search_fields = ['hostname', 'ip', 'comment'] - - def get_queryset(self): - if getattr(self, 'swagger_fake_view', False): - return Asset.objects.none() - queryset = get_user_granted_all_assets(self.user) - queryset = queryset.prefetch_related('platform') - return queryset.only(*self.only_fields) - - -class MyAllGrantedAssetsApi(ForUserMixin, UserAllGrantedAssetsApi): - pass - - -class MyAllAssetsAsTreeApi(ForUserMixin, AssetsAsTreeMixin, UserAllGrantedAssetsApi): - search_fields = ['hostname', 'ip'] - - -class UserGrantedNodeAssetsApi(UserNodeGrantStatusDispatchMixin, ListAPIView): - serializer_class = serializers.AssetGrantedSerializer - only_fields = serializers.AssetGrantedSerializer.Meta.only_fields - filterset_fields = ['hostname', 'ip', 'id', 'comment'] - search_fields = ['hostname', 'ip', 'comment'] - pagination_class = GrantedAssetLimitOffsetPagination - pagination_node: Node - - def get_queryset(self): - if getattr(self, 'swagger_fake_view', False): - return Asset.objects.none() - node_id = self.kwargs.get("node_id") - node = Node.objects.get(id=node_id) - self.pagination_node = node - return self.dispatch_get_data(node.key, self.user) - - def get_data_on_node_direct_granted(self, key): - # 如果这个节点是直接授权的(或者说祖先节点直接授权的), 获取下面的所有资产 - return Node.get_node_all_assets_by_key_v2(key) - - def get_data_on_node_indirect_granted(self, key): - self.pagination_node = self.get_mapping_node_by_key(key, self.user) - return get_node_all_granted_assets(self.user, key) - - def get_data_on_node_not_granted(self, key): - return Asset.objects.none() - - -class UserGrantedNodeAssetsForAdminApi(ForAdminMixin, UserGrantedNodeAssetsApi): - pass - - -class MyGrantedNodeAssetsApi(ForUserMixin, UserGrantedNodeAssetsApi): - pass diff --git a/apps/perms/api/asset/user_permission/user_permission_assets/__init__.py b/apps/perms/api/asset/user_permission/user_permission_assets/__init__.py new file mode 100644 index 000000000..6b274abdd --- /dev/null +++ b/apps/perms/api/asset/user_permission/user_permission_assets/__init__.py @@ -0,0 +1 @@ +from .views import * diff --git a/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py b/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py new file mode 100644 index 000000000..0b92da278 --- /dev/null +++ b/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py @@ -0,0 +1,127 @@ +from rest_framework.response import Response +from rest_framework.request import Request + +from users.models import User +from assets.api.mixin import SerializeToTreeNodeMixin +from common.utils import get_logger +from perms.pagination import NodeGrantedAssetPagination, AllGrantedAssetPagination +from assets.models import Asset, Node +from perms import serializers +from perms.utils.asset.user_permission import UserGrantedAssetsQueryUtils, QuerySetStage + +logger = get_logger(__name__) + + +# 获取数据的 ------------------------------------------------------------ + +class UserDirectGrantedAssetsQuerysetMixin: + only_fields = serializers.AssetGrantedSerializer.Meta.only_fields + user: User + + def get_queryset(self): + if getattr(self, 'swagger_fake_view', False): + return Asset.objects.none() + user = self.user + assets = UserGrantedAssetsQueryUtils(user) \ + .get_direct_granted_assets() \ + .prefetch_related('platform') \ + .only(*self.only_fields) + return assets + + +class UserAllGrantedAssetsQuerysetMixin: + only_fields = serializers.AssetGrantedSerializer.Meta.only_fields + pagination_class = AllGrantedAssetPagination + user: User + + def get_union_queryset(self, qs_stage: QuerySetStage): + if getattr(self, 'swagger_fake_view', False): + return Asset.objects.none() + qs_stage.prefetch_related('platform').only(*self.only_fields) + queryset = UserGrantedAssetsQueryUtils(self.user) \ + .get_all_granted_assets(qs_stage) + return queryset + + +class UserFavoriteGrantedAssetsMixin: + only_fields = serializers.AssetGrantedSerializer.Meta.only_fields + user: User + + def get_union_queryset(self, qs_stage: QuerySetStage): + if getattr(self, 'swagger_fake_view', False): + return Asset.objects.none() + user = self.user + qs_stage.prefetch_related('platform').only(*self.only_fields) + utils = UserGrantedAssetsQueryUtils(user) + assets = utils.get_favorite_assets(qs_stage=qs_stage) + return assets + + +class UserGrantedNodeAssetsMixin: + only_fields = serializers.AssetGrantedSerializer.Meta.only_fields + pagination_class = NodeGrantedAssetPagination + pagination_node: Node + user: User + + def get_union_queryset(self, qs_stage: QuerySetStage): + if getattr(self, 'swagger_fake_view', False): + return Asset.objects.none() + node_id = self.kwargs.get("node_id") + qs_stage.prefetch_related('platform').only(*self.only_fields) + node, assets = UserGrantedAssetsQueryUtils(self.user).get_node_all_assets( + node_id, qs_stage=qs_stage + ) + self.pagination_node = node + return assets + + +# 控制格式的 ---------------------------------------------------- + +class AssetsUnionQuerysetMixin: + def get_queryset_union_prefer(self): + if hasattr(self, 'get_union_queryset'): + # 为了支持 union 查询 + queryset = Asset.objects.all().distinct() + queryset = self.filter_queryset(queryset) + qs_stage = QuerySetStage() + qs_stage.and_with_queryset(queryset) + queryset = self.get_union_queryset(qs_stage) + else: + queryset = self.filter_queryset(self.get_queryset()) + return queryset + + +class AssetsSerializerFormatMixin(AssetsUnionQuerysetMixin): + serializer_class = serializers.AssetGrantedSerializer + filterset_fields = ['hostname', 'ip', 'id', 'comment'] + search_fields = ['hostname', 'ip', 'comment'] + + def list(self, request, *args, **kwargs): + queryset = self.get_queryset_union_prefer() + + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) + + +class AssetsTreeFormatMixin(AssetsUnionQuerysetMixin, SerializeToTreeNodeMixin): + """ + 将 资产 序列化成树的结构返回 + """ + + def list(self, request: Request, *args, **kwargs): + queryset = self.get_queryset_union_prefer() + + if request.query_params.get('search'): + # 如果用户搜索的条件不精准,会导致返回大量的无意义数据。 + # 这里限制一下返回数据的最大条数 + queryset = queryset[:999] + data = self.serialize_assets(queryset, None) + return Response(data=data) + + # def get_serializer_class(self): + # return EmptySerializer diff --git a/apps/perms/api/asset/user_permission/user_permission_assets/views.py b/apps/perms/api/asset/user_permission/user_permission_assets/views.py new file mode 100644 index 000000000..05b09442a --- /dev/null +++ b/apps/perms/api/asset/user_permission/user_permission_assets/views.py @@ -0,0 +1,99 @@ +from rest_framework.generics import ListAPIView +from django.conf import settings + +from common.utils import get_logger +from ..mixin import RoleAdminMixin, RoleUserMixin +from .mixin import ( + UserAllGrantedAssetsQuerysetMixin, UserDirectGrantedAssetsQuerysetMixin, UserFavoriteGrantedAssetsMixin, + UserGrantedNodeAssetsMixin, AssetsSerializerFormatMixin, AssetsTreeFormatMixin, +) + +__all__ = [ + 'UserDirectGrantedAssetsForAdminApi', 'MyDirectGrantedAssetsApi', 'UserFavoriteGrantedAssetsForAdminApi', + 'MyFavoriteGrantedAssetsApi', 'UserDirectGrantedAssetsAsTreeForAdminApi', 'MyUngroupAssetsAsTreeApi', + 'UserAllGrantedAssetsApi', 'MyAllGrantedAssetsApi', 'MyAllAssetsAsTreeApi', 'UserGrantedNodeAssetsForAdminApi', + 'MyGrantedNodeAssetsApi', +] + +logger = get_logger(__name__) + + +class UserDirectGrantedAssetsForAdminApi(UserDirectGrantedAssetsQuerysetMixin, + RoleAdminMixin, + AssetsSerializerFormatMixin, + ListAPIView): + pass + + +class MyDirectGrantedAssetsApi(UserDirectGrantedAssetsQuerysetMixin, + RoleUserMixin, + AssetsSerializerFormatMixin, + ListAPIView): + pass + + +class UserFavoriteGrantedAssetsForAdminApi(UserFavoriteGrantedAssetsMixin, + RoleAdminMixin, + AssetsSerializerFormatMixin, + ListAPIView): + pass + + +class MyFavoriteGrantedAssetsApi(UserFavoriteGrantedAssetsMixin, + RoleUserMixin, + AssetsSerializerFormatMixin, + ListAPIView): + pass + + +class UserDirectGrantedAssetsAsTreeForAdminApi(UserDirectGrantedAssetsQuerysetMixin, + RoleAdminMixin, + AssetsTreeFormatMixin, + ListAPIView): + pass + + +class MyUngroupAssetsAsTreeApi(UserDirectGrantedAssetsQuerysetMixin, + RoleUserMixin, + AssetsTreeFormatMixin, + ListAPIView): + def get_queryset(self): + queryset = super().get_queryset() + if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: + queryset = queryset.none() + return queryset + + +class UserAllGrantedAssetsApi(UserAllGrantedAssetsQuerysetMixin, + RoleAdminMixin, + AssetsSerializerFormatMixin, + ListAPIView): + pass + + +class MyAllGrantedAssetsApi(UserAllGrantedAssetsQuerysetMixin, + RoleUserMixin, + AssetsSerializerFormatMixin, + ListAPIView): + pass + + +class MyAllAssetsAsTreeApi(UserAllGrantedAssetsQuerysetMixin, + RoleUserMixin, + AssetsTreeFormatMixin, + ListAPIView): + search_fields = ['hostname', 'ip'] + + +class UserGrantedNodeAssetsForAdminApi(UserGrantedNodeAssetsMixin, + RoleAdminMixin, + AssetsSerializerFormatMixin, + ListAPIView): + pass + + +class MyGrantedNodeAssetsApi(UserGrantedNodeAssetsMixin, + RoleUserMixin, + AssetsSerializerFormatMixin, + ListAPIView): + pass diff --git a/apps/perms/api/asset/user_permission/user_permission_nodes.py b/apps/perms/api/asset/user_permission/user_permission_nodes.py index 58af37090..2a5cfacf2 100644 --- a/apps/perms/api/asset/user_permission/user_permission_nodes.py +++ b/apps/perms/api/asset/user_permission/user_permission_nodes.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- # import abc -from django.conf import settings from rest_framework.generics import ( ListAPIView ) @@ -10,16 +9,11 @@ from rest_framework.request import Request from assets.api.mixin import SerializeToTreeNodeMixin from common.utils import get_logger -from .mixin import ForAdminMixin, ForUserMixin, UserNodeGrantStatusDispatchMixin -from perms.hands import Node, User +from .mixin import RoleAdminMixin, RoleUserMixin +from perms.hands import User from perms import serializers -from perms.utils.asset.user_permission import ( - get_indirect_granted_node_children, - get_user_granted_nodes_list_via_mapping_node, - get_top_level_granted_nodes, - rebuild_user_tree_if_need, get_favorite_node, - get_ungrouped_node -) + +from perms.utils.asset.user_permission import UserGrantedNodesQueryUtils logger = get_logger(__name__) @@ -61,7 +55,6 @@ class BaseGrantedNodeApi(_GrantedNodeStructApi, metaclass=abc.ABCMeta): serializer_class = serializers.NodeGrantedSerializer def list(self, request, *args, **kwargs): - rebuild_user_tree_if_need(request, self.user) nodes = self.get_nodes() serializer = self.get_serializer(nodes, many=True) return Response(serializer.data) @@ -73,7 +66,6 @@ class BaseNodeChildrenApi(NodeChildrenMixin, BaseGrantedNodeApi, metaclass=abc.A class BaseGrantedNodeAsTreeApi(SerializeToTreeNodeMixin, _GrantedNodeStructApi, metaclass=abc.ABCMeta): def list(self, request: Request, *args, **kwargs): - rebuild_user_tree_if_need(request, self.user) nodes = self.get_nodes() nodes = self.serialize_nodes(nodes, with_asset_amount=True) return Response(data=nodes) @@ -83,30 +75,16 @@ class BaseNodeChildrenAsTreeApi(NodeChildrenMixin, BaseGrantedNodeAsTreeApi, met pass -class UserGrantedNodeChildrenMixin(UserNodeGrantStatusDispatchMixin): +class UserGrantedNodeChildrenMixin: user: User request: Request def get_children(self): user = self.user key = self.request.query_params.get('key') - - if not key: - nodes = list(get_top_level_granted_nodes(user)) - else: - nodes = self.dispatch_get_data(key, user) + nodes = UserGrantedNodesQueryUtils(user).get_node_children(key) return nodes - def get_data_on_node_direct_granted(self, key): - return Node.objects.filter(parent_key=key) - - def get_data_on_node_indirect_granted(self, key): - nodes = get_indirect_granted_node_children(self.user, key) - return nodes - - def get_data_on_node_not_granted(self, key): - return Node.objects.none() - class UserGrantedNodesMixin: """ @@ -115,41 +93,38 @@ class UserGrantedNodesMixin: user: User def get_nodes(self): - nodes = [] - if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: - nodes.append(get_ungrouped_node(self.user)) - nodes.append(get_favorite_node(self.user)) - nodes.extend(get_user_granted_nodes_list_via_mapping_node(self.user)) + utils = UserGrantedNodesQueryUtils(self.user) + nodes = utils.get_whole_tree_nodes() return nodes # ------------------------------------------ # 最终的 api -class UserGrantedNodeChildrenForAdminApi(ForAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi): +class UserGrantedNodeChildrenForAdminApi(RoleAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi): pass -class MyGrantedNodeChildrenApi(ForUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi): +class MyGrantedNodeChildrenApi(RoleUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi): pass -class UserGrantedNodeChildrenAsTreeForAdminApi(ForAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi): +class UserGrantedNodeChildrenAsTreeForAdminApi(RoleAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi): pass -class MyGrantedNodeChildrenAsTreeApi(ForUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi): +class MyGrantedNodeChildrenAsTreeApi(RoleUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi): pass -class UserGrantedNodesForAdminApi(ForAdminMixin, UserGrantedNodesMixin, BaseGrantedNodeApi): +class UserGrantedNodesForAdminApi(RoleAdminMixin, UserGrantedNodesMixin, BaseGrantedNodeApi): pass -class MyGrantedNodesApi(ForUserMixin, UserGrantedNodesMixin, BaseGrantedNodeApi): +class MyGrantedNodesApi(RoleUserMixin, UserGrantedNodesMixin, BaseGrantedNodeApi): pass -class MyGrantedNodesAsTreeApi(ForUserMixin, UserGrantedNodesMixin, BaseGrantedNodeAsTreeApi): +class MyGrantedNodesAsTreeApi(RoleUserMixin, UserGrantedNodesMixin, BaseGrantedNodeAsTreeApi): pass # ------------------------------------------ diff --git a/apps/perms/api/asset/user_permission/user_permission_nodes_with_assets.py b/apps/perms/api/asset/user_permission/user_permission_nodes_with_assets.py index 44d8f22b3..253a925ca 100644 --- a/apps/perms/api/asset/user_permission/user_permission_nodes_with_assets.py +++ b/apps/perms/api/asset/user_permission/user_permission_nodes_with_assets.py @@ -1,29 +1,23 @@ # -*- coding: utf-8 -*- # -from itertools import chain - from rest_framework.generics import ListAPIView from rest_framework.request import Request from rest_framework.response import Response -from django.db.models import F, Value, CharField, Q +from django.db.models import F, Value, CharField from django.conf import settings +from common.utils.common import timeit from orgs.utils import tmp_to_root_org from common.permissions import IsValidUser from common.utils import get_logger, get_object_or_none -from .mixin import UserNodeGrantStatusDispatchMixin, ForUserMixin, ForAdminMixin +from .mixin import RoleUserMixin, RoleAdminMixin from perms.utils.asset.user_permission import ( - get_indirect_granted_node_children, UNGROUPED_NODE_KEY, FAVORITE_NODE_KEY, - get_user_direct_granted_assets, get_top_level_granted_nodes, - get_user_granted_nodes_list_via_mapping_node, - get_user_granted_all_assets, rebuild_user_tree_if_need, - get_user_all_assetpermissions_id, get_favorite_node, - get_ungrouped_node, compute_tmp_mapping_node_from_perm, - TMP_GRANTED_FIELD, count_direct_granted_node_assets, - count_node_all_granted_assets + UserGrantedTreeBuildUtils, get_user_all_asset_perm_ids, + UserGrantedNodesQueryUtils, UserGrantedAssetsQueryUtils, + QuerySetStage, ) -from perms.models import AssetPermission -from assets.models import Asset, FavoriteAsset +from perms.models import AssetPermission, PermNode +from assets.models import Asset from assets.api import SerializeToTreeNodeMixin from perms.hands import Node @@ -33,76 +27,45 @@ logger = get_logger(__name__) class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView): permission_classes = (IsValidUser,) - def add_ungrouped_resource(self, data: list, user, asset_perms_id): + @timeit + def add_ungrouped_resource(self, data: list, nodes_query_utils, assets_query_utils): if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: return + ungrouped_node = nodes_query_utils.get_ungrouped_node() - ungrouped_node = get_ungrouped_node(user, asset_perms_id=asset_perms_id) - direct_granted_assets = get_user_direct_granted_assets( - user, asset_perms_id=asset_perms_id - ).annotate( + direct_granted_assets = assets_query_utils.get_direct_granted_assets().annotate( parent_key=Value(ungrouped_node.key, output_field=CharField()) ).prefetch_related('platform') data.extend(self.serialize_nodes([ungrouped_node], with_asset_amount=True)) data.extend(self.serialize_assets(direct_granted_assets)) - def add_favorite_resource(self, data: list, user, asset_perms_id): - favorite_node = get_favorite_node(user, asset_perms_id) - favorite_assets = FavoriteAsset.get_user_favorite_assets( - user, asset_perms_id=asset_perms_id - ).annotate( + @timeit + def add_favorite_resource(self, data: list, nodes_query_utils, assets_query_utils): + favorite_node = nodes_query_utils.get_favorite_node() + + qs_state = QuerySetStage().annotate( parent_key=Value(favorite_node.key, output_field=CharField()) ).prefetch_related('platform') + favorite_assets = assets_query_utils.get_favorite_assets(qs_stage=qs_state, only=()) data.extend(self.serialize_nodes([favorite_node], with_asset_amount=True)) data.extend(self.serialize_assets(favorite_assets)) + @timeit def add_node_filtered_by_system_user(self, data: list, user, asset_perms_id): - tmp_nodes = compute_tmp_mapping_node_from_perm(user, asset_perms_id=asset_perms_id) - granted_nodes_key = [] - for _node in tmp_nodes: - _granted = getattr(_node, TMP_GRANTED_FIELD, False) - if not _granted: - if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: - assets_amount = count_direct_granted_node_assets(user, _node.key, asset_perms_id) - else: - assets_amount = count_node_all_granted_assets(user, _node.key, asset_perms_id) - _node.assets_amount = assets_amount - else: - granted_nodes_key.append(_node.key) + utils = UserGrantedTreeBuildUtils(user, asset_perms_id) + nodes = utils.get_whole_tree_nodes() + data.extend(self.serialize_nodes(nodes, with_asset_amount=True)) - # 查询他们的子节点 - q = Q() - for _key in granted_nodes_key: - q |= Q(key__startswith=f'{_key}:') + def add_assets(self, data: list, assets_query_utils: UserGrantedAssetsQueryUtils): + qs_stage = QuerySetStage().annotate(parent_key=F('nodes__key')).prefetch_related('platform') - if q: - descendant_nodes = Node.objects.filter(q).distinct() - else: - descendant_nodes = Node.objects.none() - - data.extend(self.serialize_nodes(chain(tmp_nodes, descendant_nodes), with_asset_amount=True)) - - def add_assets(self, data: list, user, asset_perms_id): if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: - all_assets = get_user_granted_all_assets( - user, - via_mapping_node=False, - include_direct_granted_assets=False, - asset_perms_id=asset_perms_id - ) + all_assets = assets_query_utils.get_direct_granted_nodes_assets(qs_stage=qs_stage) else: - all_assets = get_user_granted_all_assets( - user, - via_mapping_node=False, - include_direct_granted_assets=True, - asset_perms_id=asset_perms_id - ) + all_assets = assets_query_utils.get_all_granted_assets(qs_stage=qs_stage) - all_assets = all_assets.annotate( - parent_key=F('nodes__key') - ).prefetch_related('platform') data.extend(self.serialize_assets(all_assets)) @tmp_to_root_org() @@ -117,7 +80,7 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView): user = request.user data = [] - asset_perms_id = get_user_all_assetpermissions_id(user) + asset_perms_id = get_user_all_asset_perm_ids(user) system_user_id = request.query_params.get('system_user') if system_user_id: @@ -125,89 +88,72 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView): id__in=asset_perms_id, system_users__id=system_user_id, actions__gt=0 ).values_list('id', flat=True).distinct()) - self.add_ungrouped_resource(data, user, asset_perms_id) - self.add_favorite_resource(data, user, asset_perms_id) + nodes_query_utils = UserGrantedNodesQueryUtils(user, asset_perms_id) + assets_query_utils = UserGrantedAssetsQueryUtils(user, asset_perms_id) + + self.add_ungrouped_resource(data, nodes_query_utils, assets_query_utils) + self.add_favorite_resource(data, nodes_query_utils, assets_query_utils) if system_user_id: + # 有系统用户筛选的需要重新计算树结构 self.add_node_filtered_by_system_user(data, user, asset_perms_id) else: - rebuild_user_tree_if_need(request, user) - all_nodes = get_user_granted_nodes_list_via_mapping_node(user) + all_nodes = nodes_query_utils.get_whole_tree_nodes(with_special=False) data.extend(self.serialize_nodes(all_nodes, with_asset_amount=True)) - self.add_assets(data, user, asset_perms_id) + self.add_assets(data, assets_query_utils) return Response(data=data) -class GrantedNodeChildrenWithAssetsAsTreeApiMixin(UserNodeGrantStatusDispatchMixin, - SerializeToTreeNodeMixin, +class GrantedNodeChildrenWithAssetsAsTreeApiMixin(SerializeToTreeNodeMixin, ListAPIView): """ 带资产的授权树 """ user: None - def get_data_on_node_direct_granted(self, key): - nodes = Node.objects.filter(parent_key=key) - assets = Asset.org_objects.filter(nodes__key=key).distinct() - assets = assets.prefetch_related('platform') - return nodes, assets + def ensure_key(self): + key = self.request.query_params.get('key', None) + id = self.request.query_params.get('id', None) - def get_data_on_node_indirect_granted(self, key): - user = self.user - asset_perms_id = get_user_all_assetpermissions_id(user) + if key is not None: + return key - nodes = get_indirect_granted_node_children(user, key) - - assets = Asset.org_objects.filter( - nodes__key=key, - ).filter( - granted_by_permissions__id__in=asset_perms_id - ).distinct() - assets = assets.prefetch_related('platform') - return nodes, assets - - def get_data_on_node_not_granted(self, key): - return Node.objects.none(), Asset.objects.none() - - def get_data(self, key, user): - assets, nodes = [], [] - if not key: - root_nodes = get_top_level_granted_nodes(user) - nodes.extend(root_nodes) - elif key == UNGROUPED_NODE_KEY: - assets = get_user_direct_granted_assets(user) - assets = assets.prefetch_related('platform') - elif key == FAVORITE_NODE_KEY: - assets = FavoriteAsset.get_user_favorite_assets(user) - else: - nodes, assets = self.dispatch_get_data(key, user) - return nodes, assets - - def id2key_if_have(self): - id = self.request.query_params.get('id') - if id is not None: - node = get_object_or_none(Node, id=id) - if node: - return node.key + node = get_object_or_none(Node, id=id) + if node: + return node.key def list(self, request: Request, *args, **kwargs): - key = self.request.query_params.get('key') - if key is None: - key = self.id2key_if_have() + user = self.user + key = self.ensure_key() + + nodes_query_utils = UserGrantedNodesQueryUtils(user) + assets_query_utils = UserGrantedAssetsQueryUtils(user) + + nodes = PermNode.objects.none() + assets = Asset.objects.none() + + if not key: + nodes = nodes_query_utils.get_top_level_nodes() + elif key == PermNode.UNGROUPED_NODE_KEY: + assets = assets_query_utils.get_ungroup_assets() + elif key == PermNode.FAVORITE_NODE_KEY: + assets = assets_query_utils.get_favorite_assets() + else: + nodes = nodes_query_utils.get_node_children(key) + assets = assets_query_utils.get_node_assets(key) + assets = assets.prefetch_related('platform') user = self.user - rebuild_user_tree_if_need(request, user) - nodes, assets = self.get_data(key, user) tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True) tree_assets = self.serialize_assets(assets, key) return Response(data=[*tree_nodes, *tree_assets]) -class UserGrantedNodeChildrenWithAssetsAsTreeApi(ForAdminMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin): +class UserGrantedNodeChildrenWithAssetsAsTreeApi(RoleAdminMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin): pass -class MyGrantedNodeChildrenWithAssetsAsTreeApi(ForUserMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin): +class MyGrantedNodeChildrenWithAssetsAsTreeApi(RoleUserMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin): pass diff --git a/apps/perms/api/system_user_permission.py b/apps/perms/api/system_user_permission.py index 0c026b54d..17ddfc786 100644 --- a/apps/perms/api/system_user_permission.py +++ b/apps/perms/api/system_user_permission.py @@ -1,10 +1,10 @@ from rest_framework import generics -from django.db.models import Q from django.utils.decorators import method_decorator from assets.models import SystemUser from common.permissions import IsValidUser from orgs.utils import tmp_to_root_org +from perms.utils.asset.user_permission import get_user_all_asset_perm_ids from .. import serializers @@ -16,9 +16,9 @@ class SystemUserPermission(generics.ListAPIView): def get_queryset(self): user = self.request.user + asset_perms_id = get_user_all_asset_perm_ids(user) queryset = SystemUser.objects.filter( - Q(granted_by_permissions__users=user) | - Q(granted_by_permissions__user_groups__users=user) + granted_by_permissions__id__in=asset_perms_id ).distinct() return queryset diff --git a/apps/perms/async_tasks/__init__.py b/apps/perms/async_tasks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/apps/perms/async_tasks/mapping_node_task.py b/apps/perms/async_tasks/mapping_node_task.py deleted file mode 100644 index c45527ab7..000000000 --- a/apps/perms/async_tasks/mapping_node_task.py +++ /dev/null @@ -1,47 +0,0 @@ -from django.utils.crypto import get_random_string -from perms.utils import rebuild_user_mapping_nodes_if_need_with_lock - -from common.thread_pools import SingletonThreadPoolExecutor -from common.utils import get_logger -from perms.models import RebuildUserTreeTask - -logger = get_logger(__name__) - - -class Executor(SingletonThreadPoolExecutor): - pass - - -executor = Executor() - - -def run_mapping_node_tasks(): - failed_user_ids = [] - - ident = get_random_string() - logger.debug(f'[{ident}]mapping_node_tasks running') - - while True: - task = RebuildUserTreeTask.objects.exclude( - user_id__in=failed_user_ids - ).first() - - if task is None: - break - - user = task.user - try: - rebuild_user_mapping_nodes_if_need_with_lock(user) - except: - logger.exception(f'[{ident}]mapping_node_tasks_exception') - failed_user_ids.append(user.id) - - logger.debug(f'[{ident}]mapping_node_tasks finished') - - -def submit_update_mapping_node_task(): - executor.submit(run_mapping_node_tasks) - - -def submit_update_mapping_node_task_for_user(user): - executor.submit(rebuild_user_mapping_nodes_if_need_with_lock, user) diff --git a/apps/perms/locks.py b/apps/perms/locks.py new file mode 100644 index 000000000..e1bd67f09 --- /dev/null +++ b/apps/perms/locks.py @@ -0,0 +1,11 @@ +from common.utils.lock import DistributedLock + + +class UserGrantedTreeRebuildLock(DistributedLock): + name_template = 'perms.user.asset.node.tree.rebuid..' + + def __init__(self, org_id, user_id): + name = self.name_template.format( + org_id=org_id, user_id=user_id + ) + super().__init__(name=name) diff --git a/apps/perms/migrations/0014_build_users_perm_tree.py b/apps/perms/migrations/0014_build_users_perm_tree.py index 85df89b35..a4fd97b96 100644 --- a/apps/perms/migrations/0014_build_users_perm_tree.py +++ b/apps/perms/migrations/0014_build_users_perm_tree.py @@ -1,19 +1,6 @@ # Generated by Django 2.2.13 on 2020-08-21 08:20 from django.db import migrations -from perms.tasks import dispatch_mapping_node_tasks - - -def start_build_users_perm_tree_task(apps, schema_editor): - User = apps.get_model('users', 'User') - RebuildUserTreeTask = apps.get_model('perms', 'RebuildUserTreeTask') - - user_ids = User.objects.all().values_list('id', flat=True).distinct() - RebuildUserTreeTask.objects.bulk_create( - [RebuildUserTreeTask(user_id=i) for i in user_ids] - ) - - dispatch_mapping_node_tasks.delay() class Migration(migrations.Migration): @@ -23,5 +10,4 @@ class Migration(migrations.Migration): ] operations = [ - migrations.RunPython(start_build_users_perm_tree_task) ] diff --git a/apps/perms/migrations/0018_auto_20210204_1749.py b/apps/perms/migrations/0018_auto_20210204_1749.py new file mode 100644 index 000000000..00d567c2f --- /dev/null +++ b/apps/perms/migrations/0018_auto_20210204_1749.py @@ -0,0 +1,65 @@ +# Generated by Django 3.1 on 2021-02-04 09:49 + +import assets.models.node +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('assets', '0066_remove_node_assets_amount'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('perms', '0017_auto_20210104_0435'), + ] + + operations = [ + migrations.CreateModel( + name='UserAssetGrantedTreeNodeRelation', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')), + ('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')), + ('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')), + ('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')), + ('org_id', models.CharField(blank=True, db_index=True, default='', max_length=36, verbose_name='Organization')), + ('node_key', models.CharField(db_index=True, max_length=64, verbose_name='Key')), + ('node_parent_key', models.CharField(db_index=True, default='', max_length=64, verbose_name='Parent key')), + ('node_from', models.CharField(choices=[('granted', 'Direct node granted'), ('child', 'Have children node'), ('asset', 'Direct asset granted')], db_index=True, max_length=16)), + ('node_assets_amount', models.IntegerField(default=0)), + ('node', models.ForeignKey(db_constraint=False, default=None, on_delete=django.db.models.deletion.CASCADE, related_name='granted_node_rels', to='assets.node')), + ('user', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'abstract': False, + }, + bases=(assets.models.node.FamilyMixin, models.Model), + ), + migrations.RemoveField( + model_name='usergrantedmappingnode', + name='node', + ), + migrations.RemoveField( + model_name='usergrantedmappingnode', + name='user', + ), + migrations.CreateModel( + name='PermNode', + fields=[ + ], + options={ + 'ordering': [], + 'proxy': True, + 'indexes': [], + 'constraints': [], + }, + bases=('assets.node',), + ), + migrations.DeleteModel( + name='RebuildUserTreeTask', + ), + migrations.DeleteModel( + name='UserGrantedMappingNode', + ), + ] diff --git a/apps/perms/models/asset_permission.py b/apps/perms/models/asset_permission.py index 00c0d5a97..e5a03c879 100644 --- a/apps/perms/models/asset_permission.py +++ b/apps/perms/models/asset_permission.py @@ -2,7 +2,10 @@ import logging from functools import reduce from django.utils.translation import ugettext_lazy as _ +from django.db.models import F +from common.db.models import ChoiceSet +from orgs.mixins.models import OrgModelMixin from common.db import models from common.utils import lazyproperty from assets.models import Asset, SystemUser, Node, FamilyMixin @@ -11,7 +14,7 @@ from .base import BasePermission __all__ = [ - 'AssetPermission', 'Action', 'UserGrantedMappingNode', 'RebuildUserTreeTask', + 'AssetPermission', 'Action', 'PermNode', 'UserAssetGrantedTreeNodeRelation', ] # 使用场景 @@ -135,39 +138,109 @@ class AssetPermission(BasePermission): from assets.models import Node nodes_keys = self.nodes.all().values_list('key', flat=True) assets_ids = set(self.assets.all().values_list('id', flat=True)) - nodes_assets_ids = Node.get_nodes_all_assets_ids(nodes_keys) + nodes_assets_ids = Node.get_nodes_all_assets_ids_by_keys(nodes_keys) assets_ids.update(nodes_assets_ids) assets = Asset.objects.filter(id__in=assets_ids) return assets +class UserAssetGrantedTreeNodeRelation(OrgModelMixin, FamilyMixin, models.JMSBaseModel): + class NodeFrom(ChoiceSet): + granted = 'granted', 'Direct node granted' + child = 'child', 'Have children node' + asset = 'asset', 'Direct asset granted' -class UserGrantedMappingNode(FamilyMixin, models.JMSBaseModel): - node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE, - db_constraint=False, null=True, related_name='mapping_nodes') - key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True) # '1:1:1:1' user = models.ForeignKey('users.User', db_constraint=False, on_delete=models.CASCADE) - granted = models.BooleanField(default=False, db_index=True) - asset_granted = models.BooleanField(default=False, db_index=True) - parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'), db_index=True) # '1:1:1:1' - assets_amount = models.IntegerField(default=0) + node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE, + db_constraint=False, null=False, related_name='granted_node_rels') + node_key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True) + node_parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'), db_index=True) + node_from = models.CharField(choices=NodeFrom.choices, max_length=16, db_index=True) + node_assets_amount = models.IntegerField(default=0) - GRANTED_DIRECT = 1 - GRANTED_INDIRECT = 2 - GRANTED_NONE = 0 + @property + def key(self): + return self.node_key + + @property + def parent_key(self): + return self.node_parent_key @classmethod - def get_node_granted_status(cls, key, user): - ancestor_keys = Node.get_node_ancestor_keys(key, with_self=True) - has_granted = UserGrantedMappingNode.objects.filter( - key__in=ancestor_keys, user=user - ).values_list('granted', flat=True) - if not has_granted: - return cls.GRANTED_NONE - if any(list(has_granted)): - return cls.GRANTED_DIRECT - return cls.GRANTED_INDIRECT + def get_node_granted_status(cls, user, key): + ancestor_keys = set(cls.get_node_ancestor_keys(key, with_self=True)) + ancestor_rel_nodes = cls.objects.filter(user=user, node_key__in=ancestor_keys) + + for rel_node in ancestor_rel_nodes: + if rel_node.key == key: + return rel_node.node_from, rel_node + if rel_node.node_from == cls.NodeFrom.granted: + return cls.NodeFrom.granted, None + return '', None -class RebuildUserTreeTask(models.JMSBaseModel): - user = models.ForeignKey('users.User', on_delete=models.CASCADE, verbose_name=_('User')) +class PermNode(Node): + class Meta: + proxy = True + ordering = [] + + # 特殊节点 + UNGROUPED_NODE_KEY = 'ungrouped' + UNGROUPED_NODE_VALUE = _('Ungrouped') + FAVORITE_NODE_KEY = 'favorite' + FAVORITE_NODE_VALUE = _('Favorite') + + node_from = '' + granted_assets_amount = 0 + + # 提供可以设置 资产数量的字段 + _assets_amount = None + + annotate_granted_node_rel_fields = { + 'granted_assets_amount': F('granted_node_rels__node_assets_amount'), + 'node_from': F('granted_node_rels__node_from') + } + + @property + def assets_amount(self): + _assets_amount = getattr(self, '_assets_amount') + if isinstance(_assets_amount, int): + return _assets_amount + return super().assets_amount + + @assets_amount.setter + def assets_amount(self, value): + self._assets_amount = value + + def use_granted_assets_amount(self): + self.assets_amount = self.granted_assets_amount + + @classmethod + def get_ungrouped_node(cls, assets_amount): + return cls( + id=cls.UNGROUPED_NODE_KEY, + key=cls.UNGROUPED_NODE_KEY, + value=cls.UNGROUPED_NODE_VALUE, + assets_amount=assets_amount + ) + + @classmethod + def get_favorite_node(cls, assets_amount): + node = cls( + id=cls.FAVORITE_NODE_KEY, + key=cls.FAVORITE_NODE_KEY, + value=cls.FAVORITE_NODE_VALUE, + ) + node.assets_amount = assets_amount + return node + + def get_granted_status(self, user): + status, rel_node = UserAssetGrantedTreeNodeRelation.get_node_granted_status(user, self.key) + self.node_from = status + if rel_node: + self.granted_assets_amount = rel_node.node_assets_amount + return status + + def save(self): + # 这是个只读 Model + raise NotImplementedError diff --git a/apps/perms/pagination.py b/apps/perms/pagination.py index 75cf6c493..fc5e43de7 100644 --- a/apps/perms/pagination.py +++ b/apps/perms/pagination.py @@ -1,30 +1,54 @@ from rest_framework.pagination import LimitOffsetPagination from rest_framework.request import Request +from django.db.models import Sum +from perms.models import UserAssetGrantedTreeNodeRelation from common.utils import get_logger logger = get_logger(__name__) -class GrantedAssetLimitOffsetPagination(LimitOffsetPagination): +class GrantedAssetPaginationBase(LimitOffsetPagination): + + def paginate_queryset(self, queryset, request: Request, view=None): + self._request = request + self._view = view + self._user = request.user + return super().paginate_queryset(queryset, request, view=None) + def get_count(self, queryset): exclude_query_params = { self.limit_query_param, self.offset_query_param, 'key', 'all', 'show_current_asset', - 'cache_policy', 'display', 'draw' + 'cache_policy', 'display', 'draw', + 'order', } for k, v in self._request.query_params.items(): if k not in exclude_query_params and v is not None: + logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}') return super().get_count(queryset) + return self.get_count_from_nodes(queryset) + + def get_count_from_nodes(self, queryset): + raise NotImplementedError + + +class NodeGrantedAssetPagination(GrantedAssetPaginationBase): + def get_count_from_nodes(self, queryset): node = getattr(self._view, 'pagination_node', None) if node: - logger.debug(f'{self._request.get_full_path()} hit node.assets_amount[{node.assets_amount}]') + logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}') return node.assets_amount else: + logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} not has `pagination_node` -> {self._request.get_full_path()}') return super().get_count(queryset) - def paginate_queryset(self, queryset, request: Request, view=None): - self._request = request - self._view = view - return super().paginate_queryset(queryset, request, view=None) + +class AllGrantedAssetPagination(GrantedAssetPaginationBase): + def get_count_from_nodes(self, queryset): + assets_amount = sum(UserAssetGrantedTreeNodeRelation.objects.filter( + user=self._user, node_parent_key='' + ).values_list('node_assets_amount', flat=True)) + logger.debug(f'Hit all assets amount {assets_amount} -> {self._request.get_full_path()}') + return assets_amount diff --git a/apps/perms/signals_handler/__init__.py b/apps/perms/signals_handler/__init__.py new file mode 100644 index 000000000..e0b84afea --- /dev/null +++ b/apps/perms/signals_handler/__init__.py @@ -0,0 +1,2 @@ +from . import common +from . import refresh_perms diff --git a/apps/perms/signals_handler.py b/apps/perms/signals_handler/common.py similarity index 71% rename from apps/perms/signals_handler.py rename to apps/perms/signals_handler/common.py index 9e0bfbaeb..c714de834 100644 --- a/apps/perms/signals_handler.py +++ b/apps/perms/signals_handler/common.py @@ -1,31 +1,22 @@ # -*- coding: utf-8 -*- # -from django.db.models.signals import m2m_changed, pre_delete, pre_save +from django.db.models.signals import m2m_changed from django.dispatch import receiver -from perms.tasks import create_rebuild_user_tree_task, \ - create_rebuild_user_tree_task_by_related_nodes_or_assets from users.models import User, UserGroup -from assets.models import Asset, SystemUser +from assets.models import SystemUser from applications.models import Application from common.utils import get_logger from common.exceptions import M2MReverseNotAllowed -from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR -from .models import AssetPermission, ApplicationPermission +from common.const.signals import POST_ADD +from perms.models import AssetPermission, ApplicationPermission logger = get_logger(__file__) -def handle_rebuild_user_tree(instance, action, reverse, pk_set, **kwargs): - if action.startswith('post'): - if reverse: - create_rebuild_user_tree_task(pk_set) - else: - create_rebuild_user_tree_task([instance.id]) - - -def handle_bind_groups_systemuser(instance, action, reverse, pk_set, **kwargs): +@receiver(m2m_changed, sender=User.groups.through) +def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs): """ UserGroup 增加 User 时,增加的 User 需要与 UserGroup 关联的动态系统用户相关联 """ @@ -47,53 +38,11 @@ def handle_bind_groups_systemuser(instance, action, reverse, pk_set, **kwargs): system_user.users.add(*users_id) -@receiver(m2m_changed, sender=User.groups.through) -def on_user_groups_change(**kwargs): - handle_rebuild_user_tree(**kwargs) - handle_bind_groups_systemuser(**kwargs) - - -@receiver([pre_save], sender=AssetPermission) -def on_asset_perm_deactive(instance: AssetPermission, **kwargs): - try: - old = AssetPermission.objects.only('is_active').get(id=instance.id) - if instance.is_active != old.is_active: - create_rebuild_user_tree_task_by_asset_perm(instance) - except AssetPermission.DoesNotExist: - pass - - -@receiver([pre_delete], sender=AssetPermission) -def on_asset_permission_delete(instance, **kwargs): - # 授权删除之前,查出所有相关用户 - create_rebuild_user_tree_task_by_asset_perm(instance) - - -def create_rebuild_user_tree_task_by_asset_perm(asset_perm: AssetPermission): - user_ids = set() - user_ids.update( - UserGroup.objects.filter( - assetpermissions=asset_perm, users__id__isnull=False - ).distinct().values_list('users__id', flat=True) - ) - user_ids.update( - User.objects.filter(assetpermissions=asset_perm).distinct().values_list('id', flat=True) - ) - create_rebuild_user_tree_task(user_ids) - - -def need_rebuild_mapping_node(action): - return action in (POST_REMOVE, POST_ADD, POST_CLEAR) - - @receiver(m2m_changed, sender=AssetPermission.nodes.through) def on_permission_nodes_changed(instance, action, reverse, pk_set, model, **kwargs): if reverse: raise M2MReverseNotAllowed - if need_rebuild_mapping_node(action): - create_rebuild_user_tree_task_by_asset_perm(instance) - if action != POST_ADD: return logger.debug("Asset permission nodes change signal received") @@ -110,9 +59,6 @@ def on_permission_assets_changed(instance, action, reverse, pk_set, model, **kwa if reverse: raise M2MReverseNotAllowed - if need_rebuild_mapping_node(action): - create_rebuild_user_tree_task_by_asset_perm(instance) - if action != POST_ADD: return logger.debug("Asset permission assets change signal received") @@ -150,9 +96,6 @@ def on_asset_permission_users_changed(instance, action, reverse, pk_set, model, if reverse: raise M2MReverseNotAllowed - if need_rebuild_mapping_node(action): - create_rebuild_user_tree_task(pk_set) - if action != POST_ADD: return logger.debug("Asset permission users change signal received") @@ -171,10 +114,6 @@ def on_asset_permission_user_groups_changed(instance, action, pk_set, model, if reverse: raise M2MReverseNotAllowed - if need_rebuild_mapping_node(action): - user_ids = User.objects.filter(groups__id__in=pk_set).distinct().values_list('id', flat=True) - create_rebuild_user_tree_task(user_ids) - if action != POST_ADD: return logger.debug("Asset permission user groups change signal received") @@ -187,21 +126,6 @@ def on_asset_permission_user_groups_changed(instance, action, pk_set, model, system_user.groups.add(*tuple(groups)) -@receiver(m2m_changed, sender=Asset.nodes.through) -def on_node_asset_change(action, instance, reverse, pk_set, **kwargs): - if not need_rebuild_mapping_node(action): - return - - if reverse: - asset_pk_set = pk_set - node_pk_set = [instance.id] - else: - asset_pk_set = [instance.id] - node_pk_set = pk_set - - create_rebuild_user_tree_task_by_related_nodes_or_assets.delay(node_pk_set, asset_pk_set) - - @receiver(m2m_changed, sender=ApplicationPermission.system_users.through) def on_application_permission_system_users_changed(sender, instance: ApplicationPermission, action, reverse, pk_set, **kwargs): if not instance.category_remote_app: diff --git a/apps/perms/signals_handler/refresh_perms.py b/apps/perms/signals_handler/refresh_perms.py new file mode 100644 index 000000000..d91594444 --- /dev/null +++ b/apps/perms/signals_handler/refresh_perms.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# +from django.db.models.signals import m2m_changed, pre_delete, pre_save, post_save +from django.dispatch import receiver + +from users.models import User +from assets.models import Asset +from orgs.utils import current_org +from common.utils import get_logger +from common.exceptions import M2MReverseNotAllowed +from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR +from perms.models import AssetPermission +from perms.utils.asset.user_permission import UserGrantedTreeRefreshController + + +logger = get_logger(__file__) + + +@receiver(m2m_changed, sender=User.groups.through) +def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs): + if action.startswith('post'): + if reverse: + group_ids = [instance.id] + user_ids = pk_set + else: + group_ids = pk_set + user_ids = [instance.id] + + exists = AssetPermission.user_groups.through.objects.filter(usergroup_id__in=group_ids).exists() + if exists: + org_ids = [current_org.id] + UserGrantedTreeRefreshController.add_need_refresh_orgs_for_users(org_ids, user_ids) + + +@receiver([pre_delete], sender=AssetPermission) +def on_asset_perm_pre_delete(sender, instance, **kwargs): + # 授权删除之前,查出所有相关用户 + UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id]) + + +@receiver([pre_save], sender=AssetPermission) +def on_asset_perm_pre_save(sender, instance, **kwargs): + try: + old = AssetPermission.objects.get(id=instance.id) + + if old.is_valid != instance.is_valid: + UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id]) + except AssetPermission.DoesNotExist: + pass + + +@receiver([post_save], sender=AssetPermission) +def on_asset_perm_post_save(sender, instance, created, **kwargs): + if created: + UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id]) + + +def need_rebuild_mapping_node(action): + return action in (POST_REMOVE, POST_ADD, POST_CLEAR) + + +@receiver(m2m_changed, sender=AssetPermission.nodes.through) +def on_permission_nodes_changed(sender, instance, action, reverse, **kwargs): + if reverse: + raise M2MReverseNotAllowed + + if need_rebuild_mapping_node(action): + UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id]) + + +@receiver(m2m_changed, sender=AssetPermission.assets.through) +def on_permission_assets_changed(sender, instance, action, reverse, pk_set, model, **kwargs): + if reverse: + raise M2MReverseNotAllowed + + if need_rebuild_mapping_node(action): + UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id]) + + +@receiver(m2m_changed, sender=AssetPermission.users.through) +def on_asset_permission_users_changed(sender, action, reverse, pk_set, **kwargs): + if reverse: + raise M2MReverseNotAllowed + + if need_rebuild_mapping_node(action): + UserGrantedTreeRefreshController.add_need_refresh_orgs_for_users( + [current_org.id], pk_set + ) + + +@receiver(m2m_changed, sender=AssetPermission.user_groups.through) +def on_asset_permission_user_groups_changed(sender, action, pk_set, reverse, **kwargs): + if reverse: + raise M2MReverseNotAllowed + + if need_rebuild_mapping_node(action): + user_ids = User.groups.through.objects.filter(usergroup_id__in=pk_set).distinct().values_list('user_id', flat=True) + UserGrantedTreeRefreshController.add_need_refresh_orgs_for_users( + [current_org.id], user_ids + ) + + +@receiver(m2m_changed, sender=Asset.nodes.through) +def on_node_asset_change(action, instance, reverse, pk_set, **kwargs): + if not need_rebuild_mapping_node(action): + return + + if reverse: + asset_pk_set = pk_set + node_pk_set = [instance.id] + else: + asset_pk_set = [instance.id] + node_pk_set = pk_set + + UserGrantedTreeRefreshController.add_need_refresh_on_nodes_assets_relate_change(node_pk_set, asset_pk_set) diff --git a/apps/perms/tasks.py b/apps/perms/tasks.py index fbf2ce8be..8c796ad27 100644 --- a/apps/perms/tasks.py +++ b/apps/perms/tasks.py @@ -2,39 +2,18 @@ from __future__ import absolute_import, unicode_literals from datetime import timedelta -from django.db import transaction -from django.db.models import Q from django.db.transaction import atomic from django.conf import settings from celery import shared_task from common.utils import get_logger from common.utils.timezone import now, dt_formater, dt_parser -from users.models import User from ops.celery.decorator import register_as_period_task -from assets.models import Node -from perms.models import RebuildUserTreeTask, AssetPermission -from perms.utils.asset.user_permission import rebuild_user_mapping_nodes_if_need_with_lock, lock +from perms.models import AssetPermission +from perms.utils.asset.user_permission import UserGrantedTreeRefreshController logger = get_logger(__file__) -@shared_task(queue='node_tree') -def rebuild_user_mapping_nodes_celery_task(user_id): - user = User.objects.get(id=user_id) - try: - rebuild_user_mapping_nodes_if_need_with_lock(user) - except lock.SomeoneIsDoingThis: - pass - - -@shared_task(queue='node_tree') -def dispatch_mapping_node_tasks(): - user_ids = RebuildUserTreeTask.objects.all().values_list('user_id', flat=True).distinct() - logger.info(f'>>> dispatch_mapping_node_tasks for users {list(user_ids)}') - for id in user_ids: - rebuild_user_mapping_nodes_celery_task.delay(id) - - @register_as_period_task(interval=settings.PERM_EXPIRED_CHECK_PERIODIC) @shared_task(queue='celery_check_asset_perm_expired') @atomic() @@ -60,66 +39,9 @@ def check_asset_permission_expired(): setting.value = dt_formater(end) setting.save() - ids = AssetPermission.objects.filter( + asset_perm_ids = AssetPermission.objects.filter( date_expired__gte=start, date_expired__lte=end ).distinct().values_list('id', flat=True) - logger.info(f'>>> checking {start} to {end} have {ids} expired') - dispatch_process_expired_asset_permission.delay(list(ids)) - - -@shared_task(queue='node_tree') -def dispatch_process_expired_asset_permission(asset_perms_id): - user_ids = User.objects.filter( - Q(assetpermissions__id__in=asset_perms_id) | - Q(groups__assetpermissions__id__in=asset_perms_id) - ).distinct().values_list('id', flat=True) - RebuildUserTreeTask.objects.bulk_create( - [RebuildUserTreeTask(user_id=user_id) for user_id in user_ids] - ) - - dispatch_mapping_node_tasks.delay() - - -def create_rebuild_user_tree_task(user_ids): - RebuildUserTreeTask.objects.bulk_create( - [RebuildUserTreeTask(user_id=i) for i in user_ids] - ) - transaction.on_commit(dispatch_mapping_node_tasks.delay) - - -@shared_task(queue='node_tree') -def create_rebuild_user_tree_task_by_related_nodes_or_assets(node_ids, asset_ids): - node_ids = set(node_ids) - node_keys = set() - nodes = Node.objects.filter(id__in=node_ids) - for _node in nodes: - node_keys.update(_node.get_ancestor_keys()) - node_ids.update( - Node.objects.filter(key__in=node_keys).values_list('id', flat=True) - ) - - asset_perms_id = set() - asset_perms_id.update( - AssetPermission.objects.filter( - assets__id__in=asset_ids - ).values_list('id', flat=True).distinct() - ) - asset_perms_id.update( - AssetPermission.objects.filter( - nodes__id__in=node_ids - ).values_list('id', flat=True).distinct() - ) - - user_ids = set() - user_ids.update( - User.objects.filter( - assetpermissions__id__in=asset_perms_id - ).distinct().values_list('id', flat=True) - ) - user_ids.update( - User.objects.filter( - groups__assetpermissions__id__in=asset_perms_id - ).distinct().values_list('id', flat=True) - ) - - create_rebuild_user_tree_task(user_ids) + asset_perm_ids = list(asset_perm_ids) + logger.info(f'>>> checking {start} to {end} have {asset_perm_ids} expired') + UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids_cross_orgs(asset_perm_ids) diff --git a/apps/perms/utils/asset/user_permission.py b/apps/perms/utils/asset/user_permission.py index 5b9836400..6319f7f4c 100644 --- a/apps/perms/utils/asset/user_permission.py +++ b/apps/perms/utils/asset/user_permission.py @@ -1,524 +1,748 @@ -from functools import reduce, wraps -from operator import or_, and_ -from uuid import uuid4 -import threading -import inspect +from collections import defaultdict +from typing import List, Tuple +from django.core.cache import cache from django.conf import settings -from django.db.models import F, Q, Value, BooleanField -from django.utils.translation import ugettext_lazy as _ +from django.db.models import Q, QuerySet -from common.http import is_true +from common.db.models import output_as_string +from common.utils.common import lazyproperty, timeit, Time +from assets.utils import NodeAssetsUtil from common.utils import get_logger -from common.const.distributed_lock_key import UPDATE_MAPPING_NODE_TASK_LOCK_KEY -from orgs.utils import tmp_to_root_org -from common.utils.timezone import dt_formater, now -from assets.models import Node, Asset, FavoriteAsset -from django.db.transaction import atomic -from orgs import lock -from perms.models import UserGrantedMappingNode, RebuildUserTreeTask, AssetPermission +from common.decorator import on_transaction_commit +from orgs.utils import tmp_to_org, current_org, ensure_in_real_or_default_org +from assets.models import ( + Asset, FavoriteAsset, AssetQuerySet, NodeQuerySet +) +from orgs.models import Organization +from perms.models import ( + AssetPermission, PermNode, UserAssetGrantedTreeNodeRelation, +) from users.models import User +from perms.locks import UserGrantedTreeRebuildLock + +NodeFrom = UserAssetGrantedTreeNodeRelation.NodeFrom +NODE_ONLY_FIELDS = ('id', 'key', 'parent_key', 'org_id') logger = get_logger(__name__) -ADD = 'add' -REMOVE = 'remove' -UNGROUPED_NODE_KEY = 'ungrouped' -UNGROUPED_NODE_VALUE = _('Ungrouped') -FAVORITE_NODE_KEY = 'favorite' -FAVORITE_NODE_VALUE = _('Favorite') +def get_user_all_asset_perm_ids(user) -> set: + asset_perm_ids = set() + user_perm_id = AssetPermission.users.through.objects\ + .filter(user_id=user.id) \ + .values_list('assetpermission_id', flat=True) \ + .distinct() + asset_perm_ids.update(user_perm_id) -TMP_GRANTED_FIELD = '_granted' -TMP_ASSET_GRANTED_FIELD = '_asset_granted' -TMP_GRANTED_ASSETS_AMOUNT_FIELD = '_granted_assets_amount' + group_ids = user.groups.through.objects \ + .filter(user_id=user.id) \ + .values_list('usergroup_id', flat=True) \ + .distinct() + group_ids = list(group_ids) + groups_perm_id = AssetPermission.user_groups.through.objects\ + .filter(usergroup_id__in=group_ids)\ + .values_list('assetpermission_id', flat=True) \ + .distinct() + asset_perm_ids.update(groups_perm_id) + + asset_perm_ids = AssetPermission.objects.filter( + id__in=asset_perm_ids).valid().values_list('id', flat=True) + return asset_perm_ids -# 使用场景 -# `Node.objects.annotate(**node_annotate_mapping_node)` -node_annotate_mapping_node = { - TMP_GRANTED_FIELD: F('mapping_nodes__granted'), - TMP_ASSET_GRANTED_FIELD: F('mapping_nodes__asset_granted'), - TMP_GRANTED_ASSETS_AMOUNT_FIELD: F('mapping_nodes__assets_amount') -} +class QuerySetStage: + def __init__(self): + self._prefetch_related = set() + self._only = () + self._filters = [] + self._querysets_and = [] + self._querysets_or = [] + self._order_by = None + self._annotate = [] + self._before_union_merge_funs = set() + self._after_union_merge_funs = set() + + def annotate(self, *args, **kwargs): + self._annotate.append((args, kwargs)) + self._before_union_merge_funs.add(self._merge_annotate) + return self + + def prefetch_related(self, *lookups): + self._prefetch_related.update(lookups) + self._before_union_merge_funs.add(self._merge_prefetch_related) + return self + + def only(self, *fields): + self._only = fields + self._before_union_merge_funs.add(self._merge_only) + return self + + def order_by(self, *field_names): + self._order_by = field_names + self._after_union_merge_funs.add(self._merge_order_by) + return self + + def filter(self, *args, **kwargs): + self._filters.append((args, kwargs)) + self._before_union_merge_funs.add(self._merge_filters) + return self + + def and_with_queryset(self, qs: QuerySet): + assert isinstance(qs, QuerySet), f'Must be `QuerySet`' + self._order_by = qs.query.order_by + self._after_union_merge_funs.add(self._merge_order_by) + self._querysets_and.append(qs.order_by()) + self._before_union_merge_funs.add(self._merge_querysets_and) + return self + + def or_with_queryset(self, qs: QuerySet): + assert isinstance(qs, QuerySet), f'Must be `QuerySet`' + self._order_by = qs.query.order_by + self._after_union_merge_funs.add(self._merge_order_by) + self._querysets_or.append(qs.order_by()) + self._before_union_merge_funs.add(self._merge_querysets_or) + return self + + def merge_multi_before_union(self, *querysets): + ret = [] + for qs in querysets: + qs = self.merge_before_union(qs) + ret.append(qs) + return ret + + def _merge_only(self, qs: QuerySet): + if self._only: + qs = qs.only(*self._only) + return qs + + def _merge_filters(self, qs: QuerySet): + if self._filters: + for args, kwargs in self._filters: + qs = qs.filter(*args, **kwargs) + return qs + + def _merge_querysets_and(self, qs: QuerySet): + if self._querysets_and: + for qs_and in self._querysets_and: + qs &= qs_and + return qs + + def _merge_annotate(self, qs: QuerySet): + if self._annotate: + for args, kwargs in self._annotate: + qs = qs.annotate(*args, **kwargs) + return qs + + def _merge_querysets_or(self, qs: QuerySet): + if self._querysets_or: + for qs_or in self._querysets_or: + qs |= qs_or + return qs + + def _merge_prefetch_related(self, qs: QuerySet): + if self._prefetch_related: + qs = qs.prefetch_related(*self._prefetch_related) + return qs + + def _merge_order_by(self, qs: QuerySet): + if self._order_by is not None: + qs = qs.order_by(*self._order_by) + return qs + + def merge_before_union(self, qs: QuerySet) -> QuerySet: + assert isinstance(qs, QuerySet), f'Must be `QuerySet`' + for fun in self._before_union_merge_funs: + qs = fun(qs) + return qs + + def merge_after_union(self, qs: QuerySet) -> QuerySet: + for fun in self._after_union_merge_funs: + qs = fun(qs) + return qs + + def merge(self, qs: QuerySet) -> QuerySet: + qs = self.merge_before_union(qs) + qs = self.merge_after_union(qs) + return qs -# 使用场景 -# `Node.objects.annotate(**node_annotate_set_granted)` -node_annotate_set_granted = { - TMP_GRANTED_FIELD: Value(True, output_field=BooleanField()), -} +class UserGrantedTreeRefreshController: + key_template = 'perms.user.node_tree.builded_orgs.user_id:{user_id}' + def __init__(self, user): + self.user = user + self.key = self.key_template.format(user_id=user.id) + self.client = self.get_redis_client() -def is_direct_granted_by_annotate(node): - return getattr(node, TMP_GRANTED_FIELD, False) + @classmethod + def get_redis_client(cls): + return cache.client.get_client(write=True) + def get_need_refresh_org_ids(self): + org_ids = self.client.smembers(self.key) + return {org_id.decode() for org_id in org_ids} -def is_asset_granted(node): - return getattr(node, TMP_ASSET_GRANTED_FIELD, False) + def set_all_orgs_as_builed(self): + orgs_id = [str(org_id) for org_id in self.orgs_id] + self.client.sadd(self.key, *orgs_id) + def get_need_refresh_orgs_and_fill_up(self): + orgs_id = set(str(org_id) for org_id in self.orgs_id) -def get_granted_assets_amount(node): - return getattr(node, TMP_GRANTED_ASSETS_AMOUNT_FIELD, 0) + with self.client.pipeline() as p: + p.smembers(self.key) + p.sadd(self.key, *orgs_id) + ret = p.execute() + builded_orgs_id = {org_id.decode() for org_id in ret[0]} + ids = orgs_id - builded_orgs_id + orgs = [] + if Organization.DEFAULT_ID in ids: + ids.remove(Organization.DEFAULT_ID) + orgs.append(Organization.default()) + orgs.extend(Organization.objects.filter(id__in=ids)) + logger.info(f'Need rebuild orgs are {orgs}, builed orgs are {ret[0]}, all orgs are {orgs_id}') + return orgs + @classmethod + @on_transaction_commit + def remove_builed_orgs_from_users(cls, orgs_id, users_id): + client = cls.get_redis_client() + org_ids = [str(org_id) for org_id in orgs_id] -def set_granted(obj): - setattr(obj, TMP_GRANTED_FIELD, True) + with client.pipeline() as p: + for user_id in users_id: + key = cls.key_template.format(user_id=user_id) + p.srem(key, *org_ids) + p.execute() + logger.info(f'Remove orgs from users builded tree, users:{users_id} orgs:{orgs_id}') + @classmethod + def add_need_refresh_orgs_for_users(cls, orgs_id, users_id): + cls.remove_builed_orgs_from_users(orgs_id, users_id) -def set_asset_granted(obj): - setattr(obj, TMP_ASSET_GRANTED_FIELD, True) - - -VALUE_TEMPLATE = '{stage}:{rand_str}:thread:{thread_name}:{thread_id}:{now}' - - -def _generate_value(stage=lock.DOING): - cur_thread = threading.current_thread() - - return VALUE_TEMPLATE.format( - stage=stage, - thread_name=cur_thread.name, - thread_id=cur_thread.ident, - now=dt_formater(now()), - rand_str=uuid4() - ) - - -def build_user_mapping_node_lock(func): - @wraps(func) - def wrapper(*args, **kwargs): - call_args = inspect.getcallargs(func, *args, **kwargs) - user = call_args.get('user') - if user is None or not isinstance(user, User): - raise ValueError('You function must have `user` argument') - - key = UPDATE_MAPPING_NODE_TASK_LOCK_KEY.format(user_id=user.id) - doing_value = _generate_value() - commiting_value = _generate_value(stage=lock.COMMITING) - - try: - locked = lock.acquire(key, doing_value, timeout=600) - if not locked: - logger.error(f'update_mapping_node_task_locked_failed for user: {user.id}') - raise lock.SomeoneIsDoingThis - - with atomic(savepoint=False): - func(*args, **kwargs) - ok = lock.change_lock_state_to_commiting(key, doing_value, commiting_value) - if not ok: - logger.error(f'update_mapping_node_task_timeout for user: {user.id}') - raise lock.Timeout - finally: - lock.release(key, commiting_value, doing_value) - return wrapper - - -@build_user_mapping_node_lock -def rebuild_user_mapping_nodes_if_need_with_lock(user: User): - tasks = RebuildUserTreeTask.objects.filter(user=user) - if tasks: - tasks.delete() - rebuild_user_mapping_nodes(user) - - -@build_user_mapping_node_lock -def rebuild_user_mapping_nodes_with_lock(user: User): - rebuild_user_mapping_nodes(user) - - -def compute_tmp_mapping_node_from_perm(user: User, asset_perms_id=None): - node_only_fields = ('id', 'key', 'parent_key', 'assets_amount') - - if asset_perms_id is None: - asset_perms_id = get_user_all_assetpermissions_id(user) - - # 查询直接授权节点 - nodes = Node.objects.filter( - granted_by_permissions__id__in=asset_perms_id - ).distinct().only(*node_only_fields) - granted_key_set = {_node.key for _node in nodes} - - def _has_ancestor_granted(node): + @classmethod + def add_need_refresh_on_nodes_assets_relate_change(cls, node_ids, asset_ids): """ - 判断一个节点是否有授权过的祖先节点 + 1,计算与这些资产有关的授权 + 2,计算与这些节点以及祖先节点有关的授权 """ - ancestor_keys = set(node.get_ancestor_keys()) - return ancestor_keys & granted_key_set + ensure_in_real_or_default_org() - key2leaf_nodes_mapper = {} + node_ids = set(node_ids) + ancestor_node_keys = set() + asset_perm_ids = set() - # 给授权节点设置 _granted 标识,同时去重 - for _node in nodes: - if _has_ancestor_granted(_node): - continue + nodes = PermNode.objects.filter(id__in=node_ids).only('id', 'key') + for node in nodes: + ancestor_node_keys.update(node.get_ancestor_keys()) - if _node.key not in key2leaf_nodes_mapper: - set_granted(_node) - key2leaf_nodes_mapper[_node.key] = _node + ancestor_id = PermNode.objects.filter(key__in=ancestor_node_keys).values_list('id', flat=True) + node_ids.update(ancestor_id) - # 查询授权资产关联的节点设置 - def process_direct_granted_assets(): - # 查询直接授权资产 - asset_ids = Asset.objects.filter( - granted_by_permissions__id__in=asset_perms_id - ).distinct().values_list('id', flat=True) - # 查询授权资产关联的节点设置 - granted_asset_nodes = Node.objects.filter( - assets__id__in=asset_ids - ).distinct().only(*node_only_fields) + assets_related_perms_id = AssetPermission.nodes.through.objects.filter( + node_id__in=node_ids + ).values_list('assetpermission_id', flat=True) + asset_perm_ids.update(assets_related_perms_id) - # 给资产授权关联的节点设置 _asset_granted 标识,同时去重 - for _node in granted_asset_nodes: - if _has_ancestor_granted(_node): - continue + nodes_related_perms_id = AssetPermission.assets.through.objects.filter( + asset_id__in=asset_ids + ).values_list('assetpermission_id', flat=True) + asset_perm_ids.update(nodes_related_perms_id) - if _node.key not in key2leaf_nodes_mapper: - key2leaf_nodes_mapper[_node.key] = _node - set_asset_granted(key2leaf_nodes_mapper[_node.key]) + cls.add_need_refresh_by_asset_perm_ids(asset_perm_ids) - if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: - process_direct_granted_assets() + @classmethod + def add_need_refresh_by_asset_perm_ids_cross_orgs(cls, asset_perm_ids): + org_id_perm_ids_mapper = defaultdict(set) + pairs = AssetPermission.objects.filter(id__in=asset_perm_ids).values_list('org_id', 'id') + for org_id, perm_id in pairs: + org_id_perm_ids_mapper[org_id].add(perm_id) + for org_id, perm_ids in org_id_perm_ids_mapper.items(): + with tmp_to_org(org_id): + cls.add_need_refresh_by_asset_perm_ids(perm_ids) - leaf_nodes = key2leaf_nodes_mapper.values() + @classmethod + def add_need_refresh_by_asset_perm_ids(cls, asset_perm_ids): + ensure_in_real_or_default_org() - # 计算所有祖先节点 - ancestor_keys = set() - for _node in leaf_nodes: - ancestor_keys.update(_node.get_ancestor_keys()) + group_ids = AssetPermission.user_groups.through.objects.filter( + assetpermission_id__in=asset_perm_ids + ).values_list('usergroup_id', flat=True) - # 从祖先节点 key 中去掉同时也是叶子节点的 key - ancestor_keys -= key2leaf_nodes_mapper.keys() - # 查出祖先节点 - ancestors = Node.objects.filter(key__in=ancestor_keys).only(*node_only_fields) - return [*leaf_nodes, *ancestors] + user_ids = set() + direct_user_id = AssetPermission.users.through.objects.filter( + assetpermission_id__in=asset_perm_ids + ).values_list('user_id', flat=True) + user_ids.update(direct_user_id) + group_user_ids = User.groups.through.objects.filter( + usergroup_id__in=group_ids + ).values_list('user_id', flat=True) + user_ids.update(group_user_ids) -def create_mapping_nodes(user, nodes): - to_create = [] - for node in nodes: - _granted = getattr(node, TMP_GRANTED_FIELD, False) - _asset_granted = getattr(node, TMP_ASSET_GRANTED_FIELD, False) - _granted_assets_amount = getattr(node, TMP_GRANTED_ASSETS_AMOUNT_FIELD, 0) - to_create.append(UserGrantedMappingNode( - user=user, - node=node, - key=node.key, - parent_key=node.parent_key, - granted=_granted, - asset_granted=_asset_granted, - assets_amount=_granted_assets_amount, - )) - - UserGrantedMappingNode.objects.bulk_create(to_create) - - -def set_node_granted_assets_amount(user, node, asset_perms_id=None): - """ - 不依赖`UserGrantedMappingNode`直接查询授权计算资产数量 - """ - _granted = getattr(node, TMP_GRANTED_FIELD, False) - if _granted: - assets_amount = node.assets_amount - else: - if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: - assets_amount = count_direct_granted_node_assets(user, node.key, asset_perms_id) - else: - assets_amount = count_node_all_granted_assets(user, node.key, asset_perms_id) - setattr(node, TMP_GRANTED_ASSETS_AMOUNT_FIELD, assets_amount) - - -@tmp_to_root_org() -def rebuild_user_mapping_nodes(user): - logger.info(f'>>> {dt_formater(now())} start rebuild {user} mapping nodes') - - # 先删除旧的授权树🌲 - UserGrantedMappingNode.objects.filter(user=user).delete() - asset_perms_id = get_user_all_assetpermissions_id(user) - if not asset_perms_id: - # 没有授权直接返回 - return - tmp_nodes = compute_tmp_mapping_node_from_perm(user, asset_perms_id=asset_perms_id) - for _node in tmp_nodes: - set_node_granted_assets_amount(user, _node, asset_perms_id) - create_mapping_nodes(user, tmp_nodes) - logger.info(f'>>> {dt_formater(now())} end rebuild {user} mapping nodes') - - -def rebuild_all_user_mapping_nodes(): - from users.models import User - users = User.objects.all() - for user in users: - rebuild_user_mapping_nodes(user) - - -def get_user_granted_nodes_list_via_mapping_node(user): - """ - 这里的 granted nodes, 是整棵树需要的node,推算出来的也算 - :param user: - :return: - """ - # 获取 `UserGrantedMappingNode` 中对应的 `Node` - nodes = Node.objects.filter( - mapping_nodes__user=user, - ).annotate( - **node_annotate_mapping_node - ).distinct() - - key_to_node_mapper = {} - nodes_descendant_q = Q() - - for node in nodes: - if not is_direct_granted_by_annotate(node): - # 未授权的节点资产数量设置为 `UserGrantedMappingNode` 中的数量 - node.assets_amount = get_granted_assets_amount(node) - else: - # 直接授权的节点 - # 增加查询后代节点的过滤条件 - nodes_descendant_q |= Q(key__startswith=f'{node.key}:') - key_to_node_mapper[node.key] = node - - if nodes_descendant_q: - descendant_nodes = Node.objects.filter( - nodes_descendant_q - ).annotate( - **node_annotate_set_granted + cls.remove_builed_orgs_from_users( + [current_org.id], user_ids ) - for node in descendant_nodes: - key_to_node_mapper[node.key] = node - all_nodes = key_to_node_mapper.values() - return all_nodes + @lazyproperty + def orgs_id(self): + ret = [org.id for org in self.orgs] + return ret + + @lazyproperty + def orgs(self): + orgs = [*self.user.orgs.all(), Organization.default()] + return orgs + + @timeit + def refresh_if_need(self, force=False): + user = self.user + exists = UserAssetGrantedTreeNodeRelation.objects.filter(user=user).exists() + + if force or not exists: + orgs = self.orgs + self.set_all_orgs_as_builed() + else: + orgs = self.get_need_refresh_orgs_and_fill_up() + + for org in orgs: + with tmp_to_org(org): + utils = UserGrantedTreeBuildUtils(user) + utils.rebuild_user_granted_tree() -def get_user_granted_all_assets( - user, via_mapping_node=True, - include_direct_granted_assets=True, asset_perms_id=None): - if asset_perms_id is None: - asset_perms_id = get_user_all_assetpermissions_id(user) +class UserGrantedUtilsBase: + user: User - if via_mapping_node: - granted_node_keys = UserGrantedMappingNode.objects.filter( - user=user, granted=True, - ).values_list('key', flat=True).distinct() - else: - granted_node_keys = Node.objects.filter( - granted_by_permissions__id__in=asset_perms_id - ).distinct().values_list('key', flat=True) - granted_node_keys = Node.clean_children_keys(granted_node_keys) + def __init__(self, user, asset_perm_ids=None): + self.user = user + self._asset_perm_ids = asset_perm_ids - granted_node_q = Q() - for _key in granted_node_keys: - granted_node_q |= Q(nodes__key__startswith=f'{_key}:') - granted_node_q |= Q(nodes__key=_key) + @lazyproperty + def asset_perm_ids(self) -> set: + if self._asset_perm_ids: + return self._asset_perm_ids - if include_direct_granted_assets: - assets__id = get_user_direct_granted_assets(user, asset_perms_id).values_list('id', flat=True) - q = granted_node_q | Q(id__in=list(assets__id)) - else: - q = granted_node_q - - if q: - return Asset.org_objects.filter(q).distinct() - else: - return Asset.org_objects.none() + asset_perm_ids = get_user_all_asset_perm_ids(self.user) + return asset_perm_ids -def get_node_all_granted_assets(user: User, key): - """ - 此算法依据 `UserGrantedMappingNode` 的数据查询 - 1. 查询该节点下的直接授权节点 - 2. 查询该节点下授权资产关联的节点 - """ +class UserGrantedTreeBuildUtils(UserGrantedUtilsBase): - assets = Asset.objects.none() + def get_direct_granted_nodes(self) -> NodeQuerySet: + # 查询直接授权节点 + nodes = PermNode.objects.filter( + granted_by_permissions__id__in=self.asset_perm_ids + ).distinct() + return nodes - # 查询该节点下的授权节点 - granted_mapping_nodes = UserGrantedMappingNode.objects.filter( - user=user, granted=True, - ).filter( - Q(key__startswith=f'{key}:') | Q(key=key) - ) + @lazyproperty + def direct_granted_asset_ids(self) -> list: + # 3.15 + asset_ids = AssetPermission.assets.through.objects.filter( + assetpermission_id__in=self.asset_perm_ids + ).annotate( + asset_id_str=output_as_string('asset_id') + ).values_list( + 'asset_id_str', flat=True + ).distinct() - # 根据授权节点构建资产查询条件 - granted_nodes_qs = [] - for _node in granted_mapping_nodes: - granted_nodes_qs.append(Q(nodes__key__startswith=f'{_node.key}:')) - granted_nodes_qs.append(Q(nodes__key=_node.key)) + asset_ids = list(asset_ids) + return asset_ids - # 查询该节点下的资产授权节点 - only_asset_granted_mapping_nodes = UserGrantedMappingNode.objects.filter( - user=user, - asset_granted=True, - granted=False, - ).filter(Q(key__startswith=f'{key}:') | Q(key=key)) + @timeit + def rebuild_user_granted_tree(self): + ensure_in_real_or_default_org() + logger.info(f'Rebuild user:{self.user} tree in org:{current_org}') - # 根据资产授权节点构建查询 - only_asset_granted_nodes_qs = [] - for _node in only_asset_granted_mapping_nodes: - only_asset_granted_nodes_qs.append(Q(nodes__id=_node.node_id)) + user = self.user + org_id = current_org.id - q = [] - if granted_nodes_qs: - q.append(reduce(or_, granted_nodes_qs)) + with UserGrantedTreeRebuildLock(org_id, user.id): + # 先删除旧的授权树🌲 + UserAssetGrantedTreeNodeRelation.objects.filter(user=user).delete() - if only_asset_granted_nodes_qs: - only_asset_granted_nodes_q = reduce(or_, only_asset_granted_nodes_qs) - asset_perms_id = get_user_all_assetpermissions_id(user) - only_asset_granted_nodes_q &= Q(granted_by_permissions__id__in=list(asset_perms_id)) - q.append(only_asset_granted_nodes_q) + if not self.asset_perm_ids: + # 没有授权直接返回 + return - if q: - assets = Asset.objects.filter(reduce(or_, q)).distinct() - return assets + nodes = self.compute_perm_nodes_tree() + self.compute_node_assets_amount(nodes) + if not nodes: + return + self.create_mapping_nodes(nodes) + + @timeit + def compute_perm_nodes_tree(self, node_only_fields=NODE_ONLY_FIELDS) -> list: + + # 查询直接授权节点 + nodes = self.get_direct_granted_nodes().only(*node_only_fields) + nodes = list(nodes) + + # 授权的节点 key 集合 + granted_key_set = {_node.key for _node in nodes} + + def _has_ancestor_granted(node: PermNode): + """ + 判断一个节点是否有授权过的祖先节点 + """ + ancestor_keys = set(node.get_ancestor_keys()) + return ancestor_keys & granted_key_set + + key2leaf_nodes_mapper = {} + + # 给授权节点设置 granted 标识,同时去重 + for node in nodes: + node: PermNode + if _has_ancestor_granted(node): + continue + node.node_from = NodeFrom.granted + key2leaf_nodes_mapper[node.key] = node + + # 查询授权资产关联的节点设置 + def process_direct_granted_assets(): + # 查询直接授权资产 + nodes_id = {node_id_str for node_id_str, _ in self.direct_granted_asset_id_node_id_str_pairs} + # 查询授权资产关联的节点设置 2.80 + granted_asset_nodes = PermNode.objects.filter( + id__in=nodes_id + ).distinct().only(*node_only_fields) + granted_asset_nodes = list(granted_asset_nodes) + + # 给资产授权关联的节点设置 is_asset_granted 标识,同时去重 + for node in granted_asset_nodes: + if _has_ancestor_granted(node): + continue + if node.key in key2leaf_nodes_mapper: + continue + node.node_from = NodeFrom.asset + key2leaf_nodes_mapper[node.key] = node + + if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: + process_direct_granted_assets() + + leaf_nodes = key2leaf_nodes_mapper.values() + + # 计算所有祖先节点 + ancestor_keys = set() + for node in leaf_nodes: + ancestor_keys.update(node.get_ancestor_keys()) + + # 从祖先节点 key 中去掉同时也是叶子节点的 key + ancestor_keys -= key2leaf_nodes_mapper.keys() + # 查出祖先节点 + ancestors = PermNode.objects.filter(key__in=ancestor_keys).only(*node_only_fields) + ancestors = list(ancestors) + for node in ancestors: + node.node_from = NodeFrom.child + result = [*leaf_nodes, *ancestors] + return result + + @timeit + def create_mapping_nodes(self, nodes): + user = self.user + to_create = [] + + for node in nodes: + to_create.append(UserAssetGrantedTreeNodeRelation( + user=user, + node=node, + node_key=node.key, + node_parent_key=node.parent_key, + node_from=node.node_from, + node_assets_amount=node.assets_amount, + org_id=node.org_id + )) + + UserAssetGrantedTreeNodeRelation.objects.bulk_create(to_create) + + @timeit + def _fill_direct_granted_node_assets_id_from_mem(self, nodes_key, mapper): + org_id = current_org.id + for key in nodes_key: + assets_id = PermNode.get_all_assets_id_by_node_key(org_id, key) + mapper[key].update(assets_id) + + @lazyproperty + def direct_granted_asset_id_node_id_str_pairs(self): + node_asset_pairs = Asset.nodes.through.objects.filter( + asset_id__in=self.direct_granted_asset_ids + ).annotate( + asset_id_str=output_as_string('asset_id'), + node_id_str=output_as_string('node_id') + ).values_list( + 'node_id_str', 'asset_id_str' + ) + node_asset_pairs = list(node_asset_pairs) + return node_asset_pairs + + @timeit + def compute_node_assets_amount(self, nodes: List[PermNode]): + """ + 这里计算的是一个组织的 + """ + # 直接授权了根节点,直接计算 + if len(nodes) == 1: + node = nodes[0] + if node.node_from == NodeFrom.granted and node.key.isdigit(): + with tmp_to_org(node.org): + node.granted_assets_amount = len(node.get_all_assets_id()) + return + + direct_granted_nodes_key = [] + node_id_key_mapper = {} + for node in nodes: + if node.node_from == NodeFrom.granted: + direct_granted_nodes_key.append(node.key) + node_id_key_mapper[node.id.hex] = node.key + + # 授权的节点和直接资产的映射 + nodekey_assetsid_mapper = defaultdict(set) + # 直接授权的节点,资产从完整树过来 + self._fill_direct_granted_node_assets_id_from_mem( + direct_granted_nodes_key, nodekey_assetsid_mapper + ) + + # 处理直接授权资产 + # 直接授权资产,取节点与资产的关系 + node_asset_pairs = self.direct_granted_asset_id_node_id_str_pairs + node_asset_pairs = list(node_asset_pairs) + + for node_id, asset_id in node_asset_pairs: + nkey = node_id_key_mapper[node_id] + nodekey_assetsid_mapper[nkey].add(asset_id) + + util = NodeAssetsUtil(nodes, nodekey_assetsid_mapper) + util.generate() + + for node in nodes: + assets_amount = util.get_assets_amount(node.key) + node.assets_amount = assets_amount + + def get_whole_tree_nodes(self) -> list: + node_only_fields = NODE_ONLY_FIELDS + ('value', 'full_value') + nodes = self.compute_perm_nodes_tree(node_only_fields=node_only_fields) + self.compute_node_assets_amount(nodes) + + # 查询直接授权节点的子节点 + q = Q() + for node in self.get_direct_granted_nodes().only('key'): + q |= Q(key__startswith=f'{node.key}:') + + if q: + descendant_nodes = PermNode.objects.filter(q).distinct() + else: + descendant_nodes = PermNode.objects.none() + + nodes.extend(descendant_nodes) + return nodes -def get_direct_granted_node_ids(user: User, key, asset_perms_id=None): - if asset_perms_id is None: - asset_perms_id = get_user_all_assetpermissions_id(user) +class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): - # 先查出该节点下的直接授权节点 - granted_nodes = Node.objects.filter( - Q(key__startswith=f'{key}:') | Q(key=key) - ).filter( - granted_by_permissions__id__in=asset_perms_id - ).distinct().only('id', 'key') + def get_favorite_assets(self, qs_stage: QuerySetStage = None, only=('id', )) -> AssetQuerySet: + favorite_asset_ids = FavoriteAsset.objects.filter( + user=self.user + ).values_list('asset_id', flat=True) + favorite_asset_ids = list(favorite_asset_ids) + qs_stage = qs_stage or QuerySetStage() + qs_stage.filter(id__in=favorite_asset_ids).only(*only) + assets = self.get_all_granted_assets(qs_stage) + return assets - node_ids = set() - # 根据直接授权节点查询他们的子节点 - q = Q() - for _node in granted_nodes: - q |= Q(key__startswith=f'{_node.key}:') - node_ids.add(_node.id) + def get_ungroup_assets(self) -> AssetQuerySet: + return self.get_direct_granted_assets() - if q: - descendant_ids = Node.objects.filter(q).values_list('id', flat=True).distinct() - node_ids.update(descendant_ids) - return node_ids + def get_direct_granted_assets(self) -> AssetQuerySet: + queryset = Asset.objects.order_by().filter( + granted_by_permissions__id__in=self.asset_perm_ids + ).distinct() + return queryset + + def get_direct_granted_nodes_assets(self, qs_stage: QuerySetStage = None) -> AssetQuerySet: + granted_node_ids = AssetPermission.nodes.through.objects.filter( + assetpermission_id__in=self.asset_perm_ids + ).values_list('node_id', flat=True).distinct() + granted_node_ids = list(granted_node_ids) + granted_nodes = PermNode.objects.filter(id__in=granted_node_ids).only('id', 'key') + queryset = PermNode.get_nodes_all_assets(*granted_nodes) + if qs_stage: + queryset = qs_stage.merge(queryset) + return queryset + + def get_all_granted_assets(self, qs_stage: QuerySetStage = None) -> AssetQuerySet: + nodes_assets = self.get_direct_granted_nodes_assets() + assets = self.get_direct_granted_assets() + + if qs_stage: + nodes_assets, assets = qs_stage.merge_multi_before_union(nodes_assets, assets) + queryset = nodes_assets.union(assets) + if qs_stage: + queryset = qs_stage.merge_after_union(queryset) + return queryset + + def get_node_all_assets(self, id, qs_stage: QuerySetStage = None) -> Tuple[PermNode, QuerySet]: + node = PermNode.objects.get(id=id) + granted_status = node.get_granted_status(self.user) + if granted_status == NodeFrom.granted: + assets = PermNode.get_nodes_all_assets(node) + if qs_stage: + assets = qs_stage.merge(assets) + return node, assets + elif granted_status in (NodeFrom.asset, NodeFrom.child): + node.use_granted_assets_amount() + assets = self._get_indirect_granted_node_all_assets(node, qs_stage=qs_stage) + return node, assets + else: + node.assets_amount = 0 + return node, Asset.objects.none() + + def get_node_assets(self, key) -> AssetQuerySet: + node = PermNode.objects.get(key=key) + granted_status = node.get_granted_status(self.user) + + if granted_status == NodeFrom.granted: + assets = Asset.objects.order_by().filter(nodes_id=node.id) + return assets + elif granted_status == NodeFrom.asset: + return self._get_indirect_granted_node_assets(node.id) + else: + return Asset.objects.none() + + def _get_indirect_granted_node_assets(self, id) -> AssetQuerySet: + assets = Asset.objects.order_by().filter(nodes_id=id) & self.get_direct_granted_assets() + return assets + + def _get_indirect_granted_node_all_assets(self, node, qs_stage: QuerySetStage = None) -> QuerySet: + """ + 此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询 + 1. 查询该节点下的直接授权节点 + 2. 查询该节点下授权资产关联的节点 + """ + user = self.user + + # 查询该节点下的授权节点 + granted_nodes = UserAssetGrantedTreeNodeRelation.objects.filter( + user=user, node_from=NodeFrom.granted + ).filter( + Q(node_key__startswith=f'{node.key}:') + ).only('node_id', 'node_key') + node_assets = PermNode.get_nodes_all_assets(*granted_nodes) + + # 查询该节点下的资产授权节点 + only_asset_granted_node_ids = UserAssetGrantedTreeNodeRelation.objects.filter( + user=user, node_from=NodeFrom.asset + ).filter( + Q(node_key__startswith=f'{node.key}:') + ).values_list('node_id', flat=True) + + only_asset_granted_node_ids = list(only_asset_granted_node_ids) + if node.node_from == NodeFrom.asset: + only_asset_granted_node_ids.append(node.id) + + assets = Asset.objects.filter( + nodes__id__in=only_asset_granted_node_ids, + granted_by_permissions__id__in=self.asset_perm_ids + ).distinct().order_by() + if qs_stage: + node_assets, assets = qs_stage.merge_multi_before_union(node_assets, assets) + granted_assets = node_assets.union(assets) + granted_assets = qs_stage.merge_after_union(granted_assets) + return granted_assets -def get_node_all_granted_assets_from_perm(user: User, key, asset_perms_id=None): - """ - 此算法依据 `AssetPermission` 的数据查询 - 1. 查询该节点下的直接授权节点 - 2. 查询该节点下授权资产关联的节点 - """ - if asset_perms_id is None: - asset_perms_id = get_user_all_assetpermissions_id(user) +class UserGrantedNodesQueryUtils(UserGrantedUtilsBase): + def get_node_children(self, key): + if not key: + return self.get_top_level_nodes() - # 直接授权资产查询条件 - q = ( - Q(nodes__key__startswith=f'{key}:') | Q(nodes__key=key) - ) & Q(granted_by_permissions__id__in=asset_perms_id) + node = PermNode.objects.get(key=key) + granted_status = node.get_granted_status(self.user) + if granted_status == NodeFrom.granted: + return PermNode.objects.filter(parent_key=key) + elif granted_status in (NodeFrom.asset, NodeFrom.child): + return self.get_indirect_granted_node_children(key) + else: + return PermNode.objects.none() - node_ids = get_direct_granted_node_ids(user, key, asset_perms_id) - q |= Q(nodes__id__in=node_ids) - asset_qs = Asset.objects.filter(q).distinct() - return asset_qs + def get_indirect_granted_node_children(self, key): + """ + 获取用户授权树中未授权节点的子节点 + 只匹配在 `UserAssetGrantedTreeNodeRelation` 中存在的节点 + """ + user = self.user + nodes = PermNode.objects.filter( + granted_node_rels__user=user, + parent_key=key + ).annotate( + **PermNode.annotate_granted_node_rel_fields + ).distinct() + # 设置节点授权资产数量 + for node in nodes: + node.use_granted_assets_amount() + return nodes -def get_direct_granted_node_assets_from_perm(user: User, key, asset_perms_id=None): - node_ids = get_direct_granted_node_ids(user, key, asset_perms_id) - asset_qs = Asset.objects.filter(nodes__id__in=node_ids).distinct() - return asset_qs + def get_top_level_nodes(self): + nodes = self.get_special_nodes() + nodes.extend(self.get_indirect_granted_node_children('')) + return nodes + def get_ungrouped_node(self): + assets_util = UserGrantedAssetsQueryUtils(self.user, self.asset_perm_ids) + assets_amount = assets_util.get_direct_granted_assets().count() + return PermNode.get_ungrouped_node(assets_amount) -def count_node_all_granted_assets(user: User, key, asset_perms_id=None): - return get_node_all_granted_assets_from_perm(user, key, asset_perms_id).count() + def get_favorite_node(self): + assets_query_utils = UserGrantedAssetsQueryUtils(self.user, self.asset_perm_ids) + assets_amount = assets_query_utils.get_favorite_assets().values_list('id').count() + return PermNode.get_favorite_node(assets_amount) + def get_special_nodes(self): + nodes = [] + if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: + ungrouped_node = self.get_ungrouped_node() + nodes.append(ungrouped_node) + favorite_node = self.get_favorite_node() + nodes.append(favorite_node) + return nodes -def count_direct_granted_node_assets(user: User, key, asset_perms_id=None): - return get_direct_granted_node_assets_from_perm(user, key, asset_perms_id).count() + @timeit + def get_whole_tree_nodes(self, with_special=True): + """ + 这里的 granted nodes, 是整棵树需要的node,推算出来的也算 + :param user: + :return: + """ + nodes = PermNode.objects.filter( + granted_node_rels__user=self.user + ).annotate( + **PermNode.annotate_granted_node_rel_fields + ).distinct() + key_to_node_mapper = {} + nodes_descendant_q = Q() -def get_indirect_granted_node_children(user, key=''): - """ - 获取用户授权树中未授权节点的子节点 - 只匹配在 `UserGrantedMappingNode` 中存在的节点 - """ - nodes = Node.objects.filter( - mapping_nodes__user=user, - parent_key=key - ).annotate( - _granted_assets_amount=F('mapping_nodes__assets_amount'), - _granted=F('mapping_nodes__granted') - ).distinct() + for node in nodes: + node.use_granted_assets_amount() - # 设置节点授权资产数量 - for _node in nodes: - if not is_direct_granted_by_annotate(_node): - _node.assets_amount = get_granted_assets_amount(_node) - return nodes + if node.node_from == NodeFrom.granted: + # 直接授权的节点 + # 增加查询后代节点的过滤条件 + nodes_descendant_q |= Q(key__startswith=f'{node.key}:') + key_to_node_mapper[node.key] = node - -def get_top_level_granted_nodes(user): - nodes = list(get_indirect_granted_node_children(user, key='')) - if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: - ungrouped_node = get_ungrouped_node(user) - nodes.insert(0, ungrouped_node) - favorite_node = get_favorite_node(user) - nodes.insert(0, favorite_node) - return nodes - - -def get_user_all_assetpermissions_id(user: User): - asset_perms_id = AssetPermission.objects.valid().filter( - Q(users=user) | Q(user_groups__users=user) - ).distinct().values_list('id', flat=True) - - # !!! 这个很重要,必须转换成 list,避免 Django 生成嵌套子查询 - asset_perms_id = list(asset_perms_id) - return asset_perms_id - - -def get_user_direct_granted_assets(user, asset_perms_id=None): - if asset_perms_id is None: - asset_perms_id = get_user_all_assetpermissions_id(user) - assets = Asset.org_objects.filter(granted_by_permissions__id__in=asset_perms_id).distinct() - return assets - - -def count_user_direct_granted_assets(user, asset_perms_id=None): - count = get_user_direct_granted_assets( - user, asset_perms_id=asset_perms_id - ).values_list('id').count() - return count - - -def get_ungrouped_node(user, asset_perms_id=None): - assets_amount = count_user_direct_granted_assets(user, asset_perms_id) - return Node( - id=UNGROUPED_NODE_KEY, - key=UNGROUPED_NODE_KEY, - value=UNGROUPED_NODE_VALUE, - assets_amount=assets_amount - ) - - -def get_favorite_node(user, asset_perms_id=None): - assets_amount = FavoriteAsset.get_user_favorite_assets( - user, asset_perms_id=asset_perms_id - ).values_list('id').count() - return Node( - id=FAVORITE_NODE_KEY, - key=FAVORITE_NODE_KEY, - value=FAVORITE_NODE_VALUE, - assets_amount=assets_amount - ) - - -def rebuild_user_tree_if_need(request, user): - """ - 升级授权树策略后,用户的数据可能还未初始化,为防止用户显示没有数据 - 先检查 MappingNode 如果没有数据,同步创建用户授权树 - """ - if is_true(request.query_params.get('rebuild_tree')) or \ - not UserGrantedMappingNode.objects.filter(user=user).exists(): - try: - rebuild_user_mapping_nodes_with_lock(user) - except lock.SomeoneIsDoingThis: - # 您的数据正在初始化,请稍等 - raise lock.SomeoneIsDoingThis( - detail=_('Please wait while your data is being initialized'), - code='rebuild_tree_conflict' + if nodes_descendant_q: + descendant_nodes = PermNode.objects.filter( + nodes_descendant_q ) + for node in descendant_nodes: + key_to_node_mapper[node.key] = node + + all_nodes = [] + if with_special: + special_nodes = self.get_special_nodes() + all_nodes.extend(special_nodes) + all_nodes.extend(key_to_node_mapper.values()) + return all_nodes diff --git a/utils/generate_fake_data/resources/assets.py b/utils/generate_fake_data/resources/assets.py index a0edc9f08..4c5cbba93 100644 --- a/utils/generate_fake_data/resources/assets.py +++ b/utils/generate_fake_data/resources/assets.py @@ -5,7 +5,6 @@ import forgery_py from .base import FakeDataGenerator from assets.models import * -from assets.utils import check_node_assets_amount class AdminUsersGenerator(FakeDataGenerator): @@ -93,4 +92,4 @@ class AssetsGenerator(FakeDataGenerator): self.set_assets_nodes(creates) def after_generate(self): - check_node_assets_amount() + pass