refactor tree (重构&优化资产树/用户授权树加载速度) (#5548) (#5549)

* Bai reactor tree ( 重构获取完整资产树中节点下资产总数的逻辑) (#5548)

* tree: v0.1

* tree: v0.2

* tree: v0.3

* tree: v0.4

* tree: 添加并发锁未请求到时的debug日志

* 以空间换时间的方式优化资产树

* Reactor tree togther v2 (#5576)

* Bai reactor tree ( 重构获取完整资产树中节点下资产总数的逻辑) (#5548)

* tree: v0.1

* tree: v0.2

* tree: v0.3

* tree: v0.4

* tree: 添加并发锁未请求到时的debug日志

* 以空间换时间的方式优化资产树

* 修改授权适配新方案

* 添加树处理工具

* 完成新的用户授权树计算以及修改一些信号

* 重构了获取资产的一些 api

* 重构了一些节点的api

* 整理了一些代码

* 完成了api 的重构

* 重构检查节点数量功能

* 完成重构授权树工具类

* api 添加强制刷新参数

* 整理一些信号

* 处理一些信号的问题

* 完成了信号的处理

* 重构了资产树相关的锁机制

* RebuildUserTreeTask 还得添加回来

* 优化下不能在root组织的检查函数

* 优化资产树变化时锁的使用

* 修改一些算法的小工具

* 资产树锁不再校验是否在具体组织里

* 整理了一些信号的位置

* 修复资产与节点关系维护的bug

* 去掉一些调试代码

* 修复资产授权过期检查刷新授权树的 bug

* 添加了可重入锁

* 添加一些计时,优化一些sql

* 增加 union 查询的支持

* 尝试用 sql 解决节点资产数量问题

* 开始优化计算授权树节点资产数量不用冗余表

* 新代码能跑起来了,修复一下bug

* 去掉 UserGrantedMappingNode 换成 UserAssetGrantedTreeNodeRelation

* 修了些bug,做了些优化

* 优化QuerySetStage 执行逻辑

* 与小白的内存结合了

* 删掉老的表,迁移新的 assets_amount 字段

* 优化用户授权页面资产列表 count 慢

* 修复批量命令数量不对

* 修改获取非直接授权节点的 children 的逻辑

* 获取整棵树的节点

* 回退锁

* 整理迁移脚本

* 改变更新树策略

* perf: 修改一波缩进

* fix: 修改handler名称

* 修复授权树获取资产sql 泛滥

* 修复授权规则有效bug

* 修复一些bug

* 修复一些bug

* 又修了一些小bug

* 去掉了老的 get_nodes_all_assets

* 修改一些写法

* Reactor tree togther b2 (#5570)

* fix: 修改handler名称

* perf: 优化生成树

* perf: 去掉注释

* 优化了一些

* 重新生成迁移脚本

* 去掉周期检查节点资产数量的任务

* Pr@reactor tree togther guang@perf mapping (#5573)

* fix: 修改handler名称

* perf: mapping 拆分出来

* 修改名称

* perf: 修改锁名

* perf: 去掉检查节点任务

* perf: 修改一下名称

* perf: 优化一波

Co-authored-by: Jiangjie.Bai <32935519+BaiJiangJie@users.noreply.github.com>
Co-authored-by: Bai <bugatti_it@163.com>
Co-authored-by: xinwen <coderWen@126.com>

Co-authored-by: xinwen <coderWen@126.com>
Co-authored-by: 老广 <ibuler@qq.com>
pull/5610/head
Jiangjie.Bai 2021-02-05 13:29:29 +08:00 committed by GitHub
parent 709e7af953
commit 7cf6e54f01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 1829 additions and 1480 deletions

View File

@ -3,10 +3,10 @@
from assets.api import FilterAssetByNodeMixin from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView from rest_framework.generics import RetrieveAPIView
from rest_framework.response import Response
from rest_framework import status
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.decorators import method_decorator
from assets.locks import NodeTreeUpdateLock
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet

View File

@ -1,5 +1,6 @@
from typing import List from typing import List
from common.utils.common import timeit
from assets.models import Node, Asset from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination from assets.pagination import AssetLimitOffsetPagination
from common.utils import lazyproperty from common.utils import lazyproperty
@ -7,6 +8,8 @@ from assets.utils import get_node, is_query_node_all_assets
class SerializeToTreeNodeMixin: class SerializeToTreeNodeMixin:
@timeit
def serialize_nodes(self, nodes: List[Node], with_asset_amount=False): def serialize_nodes(self, nodes: List[Node], with_asset_amount=False):
if with_asset_amount: if with_asset_amount:
def _name(node: Node): def _name(node: Node):
@ -43,6 +46,7 @@ class SerializeToTreeNodeMixin:
return platform return platform
return default return default
@timeit
def serialize_assets(self, assets, node_key=None): def serialize_assets(self, assets, node_key=None):
if node_key is None: if node_key is None:
get_pid = lambda asset: getattr(asset, 'parent_key', '') get_pid = lambda asset: getattr(asset, 'parent_key', '')

View File

@ -17,12 +17,9 @@ from common.const.signals import PRE_REMOVE, POST_REMOVE
from assets.models import Asset from assets.models import Asset
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from common.tree import TreeNodeSerializer from common.tree import TreeNodeSerializer
from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY
from orgs.mixins.api import OrgModelViewSet from orgs.mixins.api import OrgModelViewSet
from orgs.mixins import generics from orgs.mixins import generics
from orgs.lock import org_level_transaction_lock
from orgs.utils import current_org from orgs.utils import current_org
from assets.tasks import check_node_assets_amount_task
from ..hands import IsOrgAdmin from ..hands import IsOrgAdmin
from ..models import Node from ..models import Node
from ..tasks import ( from ..tasks import (
@ -31,6 +28,7 @@ from ..tasks import (
) )
from .. import serializers from .. import serializers
from .mixin import SerializeToTreeNodeMixin from .mixin import SerializeToTreeNodeMixin
from assets.locks import NodeTreeUpdateLock
logger = get_logger(__file__) logger = get_logger(__file__)
@ -50,11 +48,6 @@ class NodeViewSet(OrgModelViewSet):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.NodeSerializer serializer_class = serializers.NodeSerializer
@action(methods=[POST], detail=False, url_name='launch-check-assets-amount-task')
def launch_check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
# 仅支持根节点指直接创建子节点下的节点需要通过children接口创建 # 仅支持根节点指直接创建子节点下的节点需要通过children接口创建
def perform_create(self, serializer): def perform_create(self, serializer):
child_key = Node.org_root().get_next_child_key() child_key = Node.org_root().get_next_child_key()
@ -184,9 +177,9 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
if not include_assets: if not include_assets:
return [] return []
assets = self.instance.get_assets().only( assets = self.instance.get_assets().only(
"id", "hostname", "ip", "os", "id", "hostname", "ip", "os", "platform_id",
"org_id", "protocols", "is_active" "org_id", "protocols", "is_active",
) ).prefetch_related('platform')
return self.serialize_assets(assets, self.instance.key) return self.serialize_assets(assets, self.instance.key)
@ -219,8 +212,6 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
return Response("OK") return Response("OK")
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeAddAssetsApi(generics.UpdateAPIView): class NodeAddAssetsApi(generics.UpdateAPIView):
model = Node model = Node
serializer_class = serializers.NodeAssetsSerializer serializer_class = serializers.NodeAssetsSerializer
@ -233,8 +224,6 @@ class NodeAddAssetsApi(generics.UpdateAPIView):
instance.assets.add(*tuple(assets)) instance.assets.add(*tuple(assets))
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeRemoveAssetsApi(generics.UpdateAPIView): class NodeRemoveAssetsApi(generics.UpdateAPIView):
model = Node model = Node
serializer_class = serializers.NodeAssetsSerializer serializer_class = serializers.NodeAssetsSerializer
@ -251,8 +240,6 @@ class NodeRemoveAssetsApi(generics.UpdateAPIView):
Node.org_root().assets.add(*orphan_assets) Node.org_root().assets.add(*orphan_assets)
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class MoveAssetsToNodeApi(generics.UpdateAPIView): class MoveAssetsToNodeApi(generics.UpdateAPIView):
model = Node model = Node
serializer_class = serializers.NodeAssetsSerializer serializer_class = serializers.NodeAssetsSerializer

21
apps/assets/locks.py Normal file
View File

@ -0,0 +1,21 @@
from orgs.utils import current_org
from common.utils.lock import DistributedLock
class NodeTreeUpdateLock(DistributedLock):
name_template = 'assets.node.tree.update.<org_id:{org_id}>'
def get_name(self):
if current_org:
org_id = current_org.id
else:
org_id = 'current_org_is_null'
name = self.name_template.format(
org_id=org_id
)
return name
def __init__(self, blocking=True):
name = self.get_name()
super().__init__(name=name, blocking=blocking,
release_lock_on_transaction_commit=True)

View File

@ -0,0 +1,17 @@
# Generated by Django 3.1 on 2021-02-04 09:49
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('assets', '0065_auto_20210121_1549'),
]
operations = [
migrations.RemoveField(
model_name='node',
name='assets_amount',
),
]

View File

@ -17,7 +17,7 @@ from orgs.mixins.models import OrgModelMixin, OrgManager
from .base import ConnectivityMixin from .base import ConnectivityMixin
from .utils import Connectivity from .utils import Connectivity
__all__ = ['Asset', 'ProtocolsMixin', 'Platform'] __all__ = ['Asset', 'ProtocolsMixin', 'Platform', 'AssetQuerySet']
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,13 +41,6 @@ def default_node():
class AssetManager(OrgManager): class AssetManager(OrgManager):
def get_queryset(self):
return super().get_queryset().annotate(
platform_base=models.F('platform__base')
)
class AssetOrgManager(OrgManager):
pass pass
@ -230,7 +223,6 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
comment = models.TextField(default='', blank=True, verbose_name=_('Comment')) comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
objects = AssetManager.from_queryset(AssetQuerySet)() objects = AssetManager.from_queryset(AssetQuerySet)()
org_objects = AssetOrgManager.from_queryset(AssetQuerySet)()
_connectivity = None _connectivity = None
def __str__(self): def __str__(self):

View File

@ -11,6 +11,7 @@ from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.conf import settings from django.conf import settings
from common.utils.common import timeit
from common.utils import ( from common.utils import (
ssh_key_string_to_obj, ssh_key_gen, get_logger, lazyproperty ssh_key_string_to_obj, ssh_key_gen, get_logger, lazyproperty
) )

View File

@ -18,15 +18,3 @@ class FavoriteAsset(CommonModelMixin):
@classmethod @classmethod
def get_user_favorite_assets_id(cls, user): def get_user_favorite_assets_id(cls, user):
return cls.objects.filter(user=user).values_list('asset', flat=True) return cls.objects.filter(user=user).values_list('asset', flat=True)
@classmethod
def get_user_favorite_assets(cls, user, asset_perms_id=None):
from assets.models import Asset
from perms.utils.asset.user_permission import get_user_granted_all_assets
asset_ids = get_user_granted_all_assets(
user,
via_mapping_node=False,
asset_perms_id=asset_perms_id
).values_list('id', flat=True)
query_name = cls.asset.field.related_query_name()
return Asset.org_objects.filter(**{f'{query_name}__user_id': user.id}, id__in=asset_ids).distinct()

View File

@ -1,23 +1,32 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import uuid
import re import re
import time
import uuid
import threading
import os
import time
import uuid
from collections import defaultdict
from django.db import models, transaction from django.db import models, transaction
from django.db.models import Q from django.db.models import Q, Manager
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext from django.utils.translation import ugettext
from django.db.transaction import atomic from django.db.transaction import atomic
from django.core.cache import cache
from common.utils.lock import DistributedLock
from common.utils.common import timeit
from common.db.models import output_as_string
from common.utils import get_logger from common.utils import get_logger
from common.utils.common import lazyproperty
from orgs.mixins.models import OrgModelMixin, OrgManager from orgs.mixins.models import OrgModelMixin, OrgManager
from orgs.utils import get_current_org, tmp_to_org from orgs.utils import get_current_org, tmp_to_org
from orgs.models import Organization from orgs.models import Organization
__all__ = ['Node', 'FamilyMixin', 'compute_parent_key'] __all__ = ['Node', 'FamilyMixin', 'compute_parent_key', 'NodeQuerySet']
logger = get_logger(__name__) logger = get_logger(__name__)
@ -247,9 +256,125 @@ class FamilyMixin:
return [*tuple(ancestors), self, *tuple(children)] return [*tuple(ancestors), self, *tuple(children)]
class NodeAssetsMixin: class NodeAllAssetsMappingMixin:
# Use a new plan
# { org_id: { node_key: [ asset1_id, asset2_id ] } }
orgid_nodekey_assetsid_mapping = defaultdict(dict)
@classmethod
def get_node_all_assets_id_mapping(cls, org_id):
_mapping = cls.get_node_all_assets_id_mapping_from_memory(org_id)
if _mapping:
return _mapping
_mapping = cls.get_node_all_assets_id_mapping_from_cache_or_generate_to_cache(org_id)
cls.set_node_all_assets_id_mapping_to_memory(org_id, mapping=_mapping)
return _mapping
# from memory
@classmethod
def get_node_all_assets_id_mapping_from_memory(cls, org_id):
mapping = cls.orgid_nodekey_assetsid_mapping.get(org_id, {})
return mapping
@classmethod
def set_node_all_assets_id_mapping_to_memory(cls, org_id, mapping):
cls.orgid_nodekey_assetsid_mapping[org_id] = mapping
@classmethod
def expire_node_all_assets_id_mapping_from_memory(cls, org_id):
org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
# get order: from memory -> (from cache -> to generate)
@classmethod
def get_node_all_assets_id_mapping_from_cache_or_generate_to_cache(cls, org_id):
mapping = cls.get_node_all_assets_id_mapping_from_cache(org_id)
if mapping:
return mapping
lock_key = f'KEY_LOCK_GENERATE_ORG_{org_id}_NODE_ALL_ASSETS_ID_MAPPING'
logger.info(f'Thread[{threading.get_ident()}] acquiring lock[{lock_key}] ...')
with DistributedLock(lock_key):
logger.info(f'Thread[{threading.get_ident()}] acquire lock[{lock_key}] ok')
# 这里使用无限期锁,原因是如果这里卡住了,就卡在数据库了,说明
# 数据库繁忙,所以不应该再有线程执行这个操作,使数据库忙上加忙
# 这里最好先判断内存中有没有,防止同一进程的多个线程重复从 cache 中获取数据,
# 但逻辑过于繁琐,直接判断 cache 吧
_mapping = cls.get_node_all_assets_id_mapping_from_cache(org_id)
if _mapping:
return _mapping
_mapping = cls.generate_node_all_assets_id_mapping(org_id)
cls.set_node_all_assets_id_mapping_to_cache(org_id=org_id, mapping=_mapping)
return _mapping
@classmethod
def get_node_all_assets_id_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id)
mapping = cache.get(cache_key)
return mapping
@classmethod
def set_node_all_assets_id_mapping_to_cache(cls, org_id, mapping):
cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id)
cache.set(cache_key, mapping, timeout=None)
@classmethod
def expire_node_all_assets_id_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id)
cache.delete(cache_key)
@staticmethod
def _get_cache_key_for_node_all_assets_id_mapping(org_id):
return 'ASSETS_ORG_NODE_ALL_ASSETS_ID_MAPPING_{}'.format(org_id)
@classmethod
def generate_node_all_assets_id_mapping(cls, org_id):
from .asset import Asset
t1 = time.time()
with tmp_to_org(org_id):
nodes_id_key = Node.objects.filter(org_id=org_id) \
.annotate(char_id=output_as_string('id')) \
.values_list('char_id', 'key')
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_assets_id = Asset.nodes.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
node_id_ancestor_keys_mapping = {
node_id: cls.get_node_ancestor_keys(node_key, with_self=True)
for node_id, node_key in nodes_id_key
}
nodeid_assetsid_mapping = defaultdict(set)
for node_id, asset_id in nodes_assets_id:
nodeid_assetsid_mapping[node_id].add(asset_id)
t2 = time.time()
mapping = defaultdict(set)
for node_id, node_key in nodes_id_key:
assets_id = nodeid_assetsid_mapping[node_id]
node_ancestor_keys = node_id_ancestor_keys_mapping[node_id]
for ancestor_key in node_ancestor_keys:
mapping[ancestor_key].update(assets_id)
t3 = time.time()
logger.debug('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2-t1, t3-t2))
return mapping
class NodeAssetsMixin(NodeAllAssetsMappingMixin):
org_id: str
key = '' key = ''
id = None id = None
objects: Manager
def get_all_assets(self): def get_all_assets(self):
from .asset import Asset from .asset import Asset
@ -263,8 +388,7 @@ class NodeAssetsMixin:
# 可是 startswith 会导致表关联时 Asset 索引失效 # 可是 startswith 会导致表关联时 Asset 索引失效
from .asset import Asset from .asset import Asset
node_ids = cls.objects.filter( node_ids = cls.objects.filter(
Q(key__startswith=f'{key}:') | Q(key__startswith=f'{key}:') | Q(key=key)
Q(key=key)
).values_list('id', flat=True).distinct() ).values_list('id', flat=True).distinct()
assets = Asset.objects.filter( assets = Asset.objects.filter(
nodes__id__in=list(node_ids) nodes__id__in=list(node_ids)
@ -283,29 +407,39 @@ class NodeAssetsMixin:
return self.get_all_assets().valid() return self.get_all_assets().valid()
@classmethod @classmethod
def get_nodes_all_assets_ids(cls, nodes_keys): def get_nodes_all_assets_ids_by_keys(cls, nodes_keys):
assets_ids = cls.get_nodes_all_assets(nodes_keys).values_list('id', flat=True) nodes = Node.objects.filter(key__in=nodes_keys)
assets_ids = cls.get_nodes_all_assets(*nodes).values_list('id', flat=True)
return assets_ids return assets_ids
@classmethod @classmethod
def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None): def get_nodes_all_assets(cls, *nodes):
from .asset import Asset from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys) node_ids = set()
q = Q() descendant_node_query = Q()
node_ids = () for n in nodes:
for key in nodes_keys: node_ids.add(n.id)
q |= Q(key__startswith=f'{key}:') descendant_node_query |= Q(key__istartswith=f'{n.key}:')
q |= Q(key=key) if descendant_node_query:
if q: _ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
node_ids = Node.objects.filter(q).distinct().values_list('id', flat=True) node_ids.update(_ids)
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
q = Q(nodes__id__in=list(node_ids)) @property
if extra_assets_ids: def assets_amount(self):
q |= Q(id__in=extra_assets_ids) assets_id = self.get_all_assets_id()
if q: return len(assets_id)
return Asset.org_objects.filter(q).distinct()
else: def get_all_assets_id(self):
return Asset.objects.none() assets_id = self.get_all_assets_id_by_node_key(org_id=self.org_id, node_key=self.key)
return set(assets_id)
@classmethod
def get_all_assets_id_by_node_key(cls, org_id, node_key):
org_id = str(org_id)
nodekey_assetsid_mapping = cls.get_node_all_assets_id_mapping(org_id)
assets_id = nodekey_assetsid_mapping.get(node_key, [])
return set(assets_id)
class SomeNodesMixin: class SomeNodesMixin:
@ -416,7 +550,6 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
date_create = models.DateTimeField(auto_now_add=True) date_create = models.DateTimeField(auto_now_add=True)
parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"), parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"),
db_index=True, default='') db_index=True, default='')
assets_amount = models.IntegerField(default=0)
objects = OrgManager.from_queryset(NodeQuerySet)() objects = OrgManager.from_queryset(NodeQuerySet)()
is_node = True is_node = True

View File

@ -199,7 +199,7 @@ class SystemUser(BaseUser):
from assets.models import Node from assets.models import Node
nodes_keys = self.nodes.all().values_list('key', flat=True) nodes_keys = self.nodes.all().values_list('key', flat=True)
assets_ids = set(self.assets.all().values_list('id', flat=True)) assets_ids = set(self.assets.all().values_list('id', flat=True))
nodes_assets_ids = Node.get_nodes_all_assets_ids(nodes_keys) nodes_assets_ids = Node.get_nodes_all_assets_ids_by_keys(nodes_keys)
assets_ids.update(nodes_assets_ids) assets_ids.update(nodes_assets_ids)
assets = Asset.objects.filter(id__in=assets_ids) assets = Asset.objects.filter(id__in=assets_ids)
return assets return assets

View File

@ -111,7 +111,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
queryset = queryset.select_related('admin_user', 'domain', 'platform') queryset = queryset.prefetch_related('admin_user', 'domain', 'platform')
queryset = queryset.prefetch_related('nodes', 'labels') queryset = queryset.prefetch_related('nodes', 'labels')
return queryset return queryset
@ -166,13 +166,6 @@ class AssetDisplaySerializer(AssetSerializer):
'connectivity', 'connectivity',
] ]
@classmethod
def setup_eager_loading(cls, queryset):
queryset = super().setup_eager_loading(queryset)
queryset = queryset\
.annotate(admin_user_username=F('admin_user__username'))
return queryset
class PlatformSerializer(serializers.ModelSerializer): class PlatformSerializer(serializers.ModelSerializer):
meta = serializers.DictField(required=False, allow_null=True) meta = serializers.DictField(required=False, allow_null=True)

View File

@ -0,0 +1,2 @@
from .common import *
from .maintain_nodes_tree import *

View File

@ -1,21 +1,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from operator import add, sub
from assets.utils import is_asset_exists_in_node
from django.db.models.signals import ( from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete, pre_save post_save, m2m_changed, pre_delete, post_delete, pre_save
) )
from django.db.models import Q, F
from django.dispatch import receiver from django.dispatch import receiver
from common.exceptions import M2MReverseNotAllowed from common.exceptions import M2MReverseNotAllowed
from common.const.signals import PRE_ADD, POST_ADD, POST_REMOVE, PRE_CLEAR, PRE_REMOVE from common.const.signals import POST_ADD, POST_REMOVE, PRE_REMOVE
from common.utils import get_logger from common.utils import get_logger
from common.decorator import on_transaction_commit from common.decorator import on_transaction_commit
from .models import Asset, SystemUser, Node, compute_parent_key from assets.models import Asset, SystemUser, Node
from users.models import User from users.models import User
from .tasks import ( from assets.tasks import (
update_assets_hardware_info_util, update_assets_hardware_info_util,
test_asset_connectivity_util, test_asset_connectivity_util,
push_system_user_to_assets_manual, push_system_user_to_assets_manual,
@ -23,7 +19,6 @@ from .tasks import (
add_nodes_assets_to_system_users add_nodes_assets_to_system_users
) )
logger = get_logger(__file__) logger = get_logger(__file__)
@ -202,134 +197,6 @@ def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
m2m_model.objects.bulk_create(to_create) m2m_model.objects.bulk_create(to_create)
def _update_node_assets_amount(node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
def _remove_ancestor_keys(ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
def _update_nodes_asset_amount(node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
_remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = is_asset_exists_in_node(asset_pk, parent_key)
if exists:
_remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@receiver(m2m_changed, sender=Asset.nodes.through)
def update_nodes_assets_amount(action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
_update_node_assets_amount(node, asset_pk_set, operator)
else:
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
_update_nodes_asset_amount(node_keys, asset_pk, operator)
RELATED_NODE_IDS = '_related_node_ids' RELATED_NODE_IDS = '_related_node_ids'

View File

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
#
import os
import threading
from django.db.models.signals import (
m2m_changed, post_save, post_delete
)
from django.dispatch import receiver
from django.utils.functional import LazyObject
from common.signals import django_ready
from common.utils.connection import RedisPubSub
from common.utils import get_logger
from assets.models import Asset, Node
logger = get_logger(__file__)
# clear node assets mapping for memory
# ------------------------------------
def get_node_assets_mapping_for_memory_pub_sub():
return RedisPubSub('fm.node_all_assets_id_memory_mapping')
class NodeAssetsMappingForMemoryPubSub(LazyObject):
def _setup(self):
self._wrapped = get_node_assets_mapping_for_memory_pub_sub()
node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub()
def expire_node_assets_mapping_for_memory(org_id):
# 所有进程清除(自己的 memory 数据)
org_id = str(org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
# 当前进程清除(cache 数据)
logger.debug(
"Expire node assets id mapping from cache of org={}, pid={}"
"".format(org_id, os.getpid())
)
Node.expire_node_all_assets_id_mapping_from_cache(org_id)
@receiver(post_save, sender=Node)
def on_node_post_create(sender, instance, created, update_fields, **kwargs):
if created:
need_expire = True
elif update_fields and 'key' in update_fields:
need_expire = True
else:
need_expire = False
if need_expire:
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(post_delete, sender=Node)
def on_node_post_delete(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(django_ready)
def subscribe_node_assets_mapping_expire(sender, **kwargs):
logger.debug("Start subscribe for expire node assets id mapping from memory")
def keep_subscribe():
subscribe = node_assets_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen():
if message["type"] != "message":
continue
org_id = message['data'].decode()
Node.expire_node_all_assets_id_mapping_from_memory(org_id)
logger.debug(
"Expire node assets id mapping from memory of org={}, pid={}"
"".format(str(org_id), os.getpid())
)
t = threading.Thread(target=keep_subscribe)
t.daemon = True
t.start()

View File

@ -12,6 +12,7 @@ __all__ = ['add_nodes_assets_to_system_users']
@tmp_to_root_org() @tmp_to_root_org()
def add_nodes_assets_to_system_users(nodes_keys, system_users): def add_nodes_assets_to_system_users(nodes_keys, system_users):
from ..models import Node from ..models import Node
assets = Node.get_nodes_all_assets(nodes_keys).values_list('id', flat=True) nodes = Node.objects.filter(key__in=nodes_keys)
assets = Node.get_nodes_all_assets(*nodes)
for system_user in system_users: for system_user in system_users:
system_user.assets.add(*tuple(assets)) system_user.assets.add(*tuple(assets))

View File

@ -141,7 +141,8 @@ def gather_asset_users(assets, task_name=None):
@shared_task(queue="ansible") @shared_task(queue="ansible")
def gather_nodes_asset_users(nodes_key): def gather_nodes_asset_users(nodes_key):
assets = Node.get_nodes_all_assets(nodes_key) nodes = Node.objects.filter(key__in=nodes_key)
assets = Node.get_nodes_all_assets(*nodes)
assets_groups_by_100 = [assets[i:i+100] for i in range(0, len(assets), 100)] assets_groups_by_100 = [assets[i:i+100] for i in range(0, len(assets), 100)]
for _assets in assets_groups_by_100: for _assets in assets_groups_by_100:
gather_asset_users(_assets) gather_asset_users(_assets)

View File

@ -1,27 +0,0 @@
from celery import shared_task
from django.utils.translation import gettext_lazy as _
from orgs.models import Organization
from orgs.utils import tmp_to_org
from ops.celery.decorator import register_as_period_task
from assets.utils import check_node_assets_amount
from common.utils.lock import AcquireFailed
from common.utils import get_logger
logger = get_logger(__file__)
@shared_task(queue='celery_heavy_tasks')
def check_node_assets_amount_task(org_id=Organization.ROOT_ID):
try:
with tmp_to_org(Organization.get_instance(org_id)):
check_node_assets_amount()
except AcquireFailed:
logger.error(_('The task of self-checking is already running and cannot be started repeatedly'))
@register_as_period_task(crontab='0 2 * * *')
@shared_task(queue='celery_heavy_tasks')
def check_node_assets_amount_period_task():
check_node_assets_amount_task()

33
apps/assets/tests/tree.py Normal file
View File

@ -0,0 +1,33 @@
from assets.tree import Tree
def test():
from orgs.models import Organization
from assets.models import Node, Asset
import time
Organization.objects.get(id='1863cf22-f666-474e-94aa-935fe175203c').change_to()
t1 = time.time()
nodes = list(Node.objects.exclude(key__startswith='-').only('id', 'key', 'parent_key'))
node_asset_id_pairs = Asset.nodes.through.objects.all().values_list('node_id', 'asset_id')
t2 = time.time()
node_asset_id_pairs = list(node_asset_id_pairs)
tree = Tree(nodes, node_asset_id_pairs)
tree.build_tree()
tree.nodes = None
tree.node_asset_id_pairs = None
import pickle
d = pickle.dumps(tree)
print('------------', len(d))
return tree
tree.compute_tree_node_assets_amount()
print(f'校对算法准确性 ......')
for node in nodes:
tree_node = tree.key_tree_node_mapper[node.key]
if tree_node.assets_amount != node.assets_amount:
print(f'ERROR: {tree_node.assets_amount} {node.assets_amount}')
# print(f'OK {tree_node.asset_amount} {node.assets_amount}')
print(f'数据库时间: {t2 - t1}')
return tree

View File

@ -2,7 +2,6 @@
from django.urls import path, re_path from django.urls import path, re_path
from rest_framework_nested import routers from rest_framework_nested import routers
from rest_framework_bulk.routes import BulkRouter from rest_framework_bulk.routes import BulkRouter
from django.db.transaction import non_atomic_requests
from common import api as capi from common import api as capi
@ -57,9 +56,9 @@ urlpatterns = [
path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'), path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'),
path('nodes/<uuid:pk>/children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'), path('nodes/<uuid:pk>/children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'),
path('nodes/<uuid:pk>/assets/', api.NodeAssetsApi.as_view(), name='node-assets'), path('nodes/<uuid:pk>/assets/', api.NodeAssetsApi.as_view(), name='node-assets'),
path('nodes/<uuid:pk>/assets/add/', non_atomic_requests(api.NodeAddAssetsApi.as_view()), name='node-add-assets'), path('nodes/<uuid:pk>/assets/add/', api.NodeAddAssetsApi.as_view(), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', non_atomic_requests(api.MoveAssetsToNodeApi.as_view()), name='node-replace-assets'), path('nodes/<uuid:pk>/assets/replace/', api.MoveAssetsToNodeApi.as_view(), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', non_atomic_requests(api.NodeRemoveAssetsApi.as_view()), name='node-remove-assets'), path('nodes/<uuid:pk>/assets/remove/', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'),
path('nodes/<uuid:pk>/tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'), path('nodes/<uuid:pk>/tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'),
path('gateways/<uuid:pk>/test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'), path('gateways/<uuid:pk>/test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'),

View File

@ -1,43 +1,16 @@
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
# #
import time from collections import defaultdict
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit
from django.db.models import Q
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none
from common.utils.lock import DistributedLock
from common.http import is_true from common.http import is_true
from .models import Asset, Node from common.struct import Stack
from common.db.models import output_as_string
from .models import Node
logger = get_logger(__file__) logger = get_logger(__file__)
@DistributedLock(name="assets.node.check_node_assets_amount", blocking=False)
def check_node_assets_amount():
for node in Node.objects.all():
logger.info(f'Check node assets amount: {node}')
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
if node.assets_amount != assets_amount:
logger.warn(f'Node wrong assets amount <Node:{node.key}> '
f'{node.assets_amount} right is {assets_amount}')
node.assets_amount = assets_amount
node.save()
# 防止自检程序给数据库的压力太大
time.sleep(0.1)
def is_asset_exists_in_node(asset_pk, node_key):
return Asset.objects.filter(
id=asset_pk
).filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).exists()
def is_query_node_all_assets(request): def is_query_node_all_assets(request):
request = request request = request
query_all_arg = request.query_params.get('all', 'true') query_all_arg = request.query_params.get('all', 'true')
@ -57,3 +30,79 @@ def get_node(request):
else: else:
node = get_object_or_none(Node, key=node_id) node = get_object_or_none(Node, key=node_id)
return node return node
class NodeAssetsInfo:
__slots__ = ('key', 'assets_amount', 'assets')
def __init__(self, key, assets_amount, assets):
self.key = key
self.assets_amount = assets_amount
self.assets = assets
def __str__(self):
return self.key
class NodeAssetsUtil:
def __init__(self, nodes, nodekey_assetsid_mapper):
"""
:param nodes: 节点
:param nodekey_assetsid_mapper: 节点直接资产id的映射 {"key1": set(), "key2": set()}
"""
self.nodes = nodes
# node_id --> set(asset_id1, asset_id2)
self.nodekey_assetsid_mapper = nodekey_assetsid_mapper
self.nodekey_assetsinfo_mapper = {}
@timeit
def generate(self):
# 准备排序好的资产信息数据
infos = []
for node in self.nodes:
assets = self.nodekey_assetsid_mapper.get(node.key, set())
info = NodeAssetsInfo(key=node.key, assets_amount=0, assets=assets)
infos.append(info)
infos = sorted(infos, key=lambda i: [int(i) for i in i.key.split(':')])
# 这个守卫需要添加一下,避免最后一个无法出栈
guarder = NodeAssetsInfo(key='', assets_amount=0, assets=set())
infos.append(guarder)
stack = Stack()
for info in infos:
# 如果栈顶的不是这个节点的父祖节点,那么可以出栈了,可以计算资产数量了
while stack.top and not info.key.startswith(f'{stack.top.key}:'):
pop_info = stack.pop()
pop_info.assets_amount = len(pop_info.assets)
self.nodekey_assetsinfo_mapper[pop_info.key] = pop_info
if not stack.top:
continue
stack.top.assets.update(pop_info.assets)
stack.push(info)
def get_assets_by_key(self, key):
info = self.nodekey_assetsinfo_mapper[key]
return info['assets']
def get_assets_amount(self, key):
info = self.nodekey_assetsinfo_mapper[key]
return info.assets_amount
@classmethod
def test_it(cls):
from assets.models import Node, Asset
nodes = list(Node.objects.all())
nodes_assets = Asset.nodes.through.objects.all()\
.annotate(aid=output_as_string('asset_id'))\
.values_list('node__key', 'aid')
mapping = defaultdict(set)
for key, asset_id in nodes_assets:
mapping[key].add(asset_id)
util = cls(nodes, mapping)
util.generate()
return util

View File

@ -1,2 +0,0 @@
UPDATE_NODE_TREE_LOCK_KEY = 'org_level_transaction_lock_{org_id}_assets_update_node_tree'
UPDATE_MAPPING_NODE_TASK_LOCK_KEY = 'org_level_transaction_lock_{user_id}_update_mapping_node_task'

View File

@ -82,3 +82,7 @@ class JMSModel(JMSBaseModel):
def concated_display(name1, name2): def concated_display(name1, name2):
return Concat(F(name1), Value('('), F(name2), Value(')')) return Concat(F(name1), Value('('), F(name2), Value(')'))
def output_as_string(field_name):
return ExpressionWrapper(F(field_name), output_field=CharField())

View File

@ -254,3 +254,22 @@ def get_disk_usage():
mount_points = [p.mountpoint for p in partitions] mount_points = [p.mountpoint for p in partitions]
usages = {p: psutil.disk_usage(p) for p in mount_points} usages = {p: psutil.disk_usage(p) for p in mount_points}
return usages return usages
class Time:
def __init__(self):
self._timestamps = []
self._msgs = []
def begin(self):
self._timestamps.append(time.time())
def time(self, msg):
self._timestamps.append(time.time())
self._msgs.append(msg)
def print(self):
last, *timestamps = self._timestamps
for timestamp, msg in zip(timestamps, self._msgs):
logger.debug(f'TIME_IT: {msg} {timestamp-last}')
last = timestamp

View File

@ -1,8 +1,9 @@
from functools import wraps from functools import wraps
import threading import threading
from redis_lock import Lock as RedisLock from redis_lock import Lock as RedisLock, NotAcquired
from redis import Redis from redis import Redis
from django.db import transaction
from common.utils import get_logger from common.utils import get_logger
from common.utils.inspect import copy_function_args from common.utils.inspect import copy_function_args
@ -16,7 +17,8 @@ class AcquireFailed(RuntimeError):
class DistributedLock(RedisLock): class DistributedLock(RedisLock):
def __init__(self, name, blocking=True, expire=60*2, auto_renewal=True): def __init__(self, name, blocking=True, expire=None, release_lock_on_transaction_commit=False,
release_raise_exc=False, auto_renewal_seconds=60*2):
""" """
使用 redis 构造的分布式锁 使用 redis 构造的分布式锁
@ -25,31 +27,46 @@ class DistributedLock(RedisLock):
:param blocking: :param blocking:
该参数只在锁作为装饰器或者 `with` 时有效 该参数只在锁作为装饰器或者 `with` 时有效
:param expire: :param expire:
锁的过期时间注意不一定是锁到这个时间就释放了分两种情况 锁的过期时间
`auto_renewal=False` 锁会释放 :param release_lock_on_transaction_commit:
`auto_renewal=True` 如果过期之前程序还没释放锁我们会延长锁的存活时间 是否在当前事务结束后再释放锁
这里的作用是防止程序意外终止没有释放锁导致死锁 :param release_raise_exc:
释放锁时如果没有持有锁是否抛异常或静默
:param auto_renewal_seconds:
当持有一个无限期锁的时候刷新锁的时间具体参考 `redis_lock.Lock#auto_renewal`
""" """
self.kwargs_copy = copy_function_args(self.__init__, locals()) self.kwargs_copy = copy_function_args(self.__init__, locals())
redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD) redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
if expire is None:
expire = auto_renewal_seconds
auto_renewal = True
else:
auto_renewal = False
super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal) super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal)
self._blocking = blocking self._blocking = blocking
self._release_lock_on_transaction_commit = release_lock_on_transaction_commit
self._release_raise_exc = release_raise_exc
def __enter__(self): def __enter__(self):
thread_id = threading.current_thread().ident thread_id = threading.current_thread().ident
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> attempt to acquire <lock:{self._name}> ...') logger.debug(f'Attempt to acquire global lock: thread {thread_id} lock {self._name}')
acquired = self.acquire(blocking=self._blocking) acquired = self.acquire(blocking=self._blocking)
if self._blocking and not acquired: if self._blocking and not acquired:
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> was not acquired <lock:{self._name}>, but blocking=True') logger.debug(f'Not acquired lock, but blocking=True, thread {thread_id} lock {self._name}')
raise EnvironmentError("Lock wasn't acquired, but blocking=True") raise EnvironmentError("Lock wasn't acquired, but blocking=True")
if not acquired: if not acquired:
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> acquire <lock:{self._name}> failed') logger.debug(f'Not acquired the lock, thread {thread_id} lock {self._name}')
raise AcquireFailed raise AcquireFailed
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> acquire <lock:{self._name}> ok') logger.debug(f'Acquire lock success, thread {thread_id} lock {self._name}')
return self return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None): def __exit__(self, exc_type=None, exc_value=None, traceback=None):
self.release() if self._release_lock_on_transaction_commit:
transaction.on_commit(self.release)
else:
self.release()
def __call__(self, func): def __call__(self, func):
@wraps(func) @wraps(func)
@ -57,5 +74,17 @@ class DistributedLock(RedisLock):
# 要创建一个新的锁对象 # 要创建一个新的锁对象
with self.__class__(**self.kwargs_copy): with self.__class__(**self.kwargs_copy):
return func(*args, **kwds) return func(*args, **kwds)
return inner return inner
def locked_by_me(self):
if self.locked():
if self.get_owner_id() == self.id:
return True
return False
def release(self):
try:
super().release()
except AcquireFailed as e:
if self._release_raise_exc:
raise e

View File

@ -1,131 +0,0 @@
from uuid import uuid4
from functools import wraps
from django.core.cache import cache
from django.db.transaction import atomic
from rest_framework.request import Request
from rest_framework.exceptions import NotAuthenticated
from orgs.utils import current_org
from common.exceptions import SomeoneIsDoingThis, Timeout
from common.utils.timezone import dt_formater, now
# Redis 中锁值得模板,该模板提供了很强的可读性,方便调试与排错
VALUE_TEMPLATE = '{stage}:{username}:{user_id}:{now}:{rand_str}'
# 锁的状态
DOING = 'doing' # 处理中,此状态的锁可以被干掉
COMMITING = 'commiting' # 提交事务中,此状态很重要,要确保事务在锁消失之前返回了,不要轻易删除该锁
client = cache.client.get_client(write=True)
"""
将锁的状态从 `doing` 切换到 `commiting`
KEYS[1] key
ARGV[1]: doingvalue
ARGV[2]: commitingvalue
ARGV[3]: timeout
"""
change_lock_state_to_commiting_lua = '''
if (redis.call("get", KEYS[1]) == ARGV[1])
then
return redis.call("set", KEYS[1], ARGV[2], "EX", ARGV[3], "XX")
else
return 0
end
'''
change_lock_state_to_commiting_lua_obj = client.register_script(change_lock_state_to_commiting_lua)
"""
释放锁两种`value`都要检查`doing``commiting`
KEYS[1] key
ARGV[1]: 两个 `value` 中的其中一个
ARGV[2]: 两个 `value` 中的其中一个
"""
release_lua = '''
if (redis.call("get",KEYS[1]) == ARGV[1] or redis.call("get",KEYS[1]) == ARGV[2])
then
return redis.call("del",KEYS[1])
else
return 0
end
'''
release_lua_obj = client.register_script(release_lua)
def acquire(key, value, timeout):
return client.set(key, value, ex=timeout, nx=True)
def get(key):
return client.get(key)
def change_lock_state_to_commiting(key, doingvalue, commitingvalue, timeout=600):
# 将锁的状态从 `doing` 切换到 `commiting`
return bool(change_lock_state_to_commiting_lua_obj(keys=(key,), args=(doingvalue, commitingvalue, timeout)))
def release(key, value1, value2):
# 释放锁,两种`value` `doing`和`commiting` 都要检查
return release_lua_obj(keys=(key,), args=(value1, value2))
def _generate_value(request: Request, stage=DOING):
# 不支持匿名用户
user = request.user
if user.is_anonymous:
raise NotAuthenticated
return VALUE_TEMPLATE.format(
stage=stage, username=user.username, user_id=user.id,
now=dt_formater(now()), rand_str=uuid4()
)
default_wait_msg = SomeoneIsDoingThis.default_detail
def org_level_transaction_lock(key, timeout=300, wait_msg=default_wait_msg):
"""
被装饰的 `View` 必须取消自身的 `ATOMIC_REQUESTS`因为该装饰器要有事务的完全控制权
[官网](https://docs.djangoproject.com/en/3.1/topics/db/transactions/#tying-transactions-to-http-requests)
1. 获取锁只有当锁对应的 `key` 不存在时成功获取`value` 设置为 `doing`
2. 开启事务本次请求的事务必须确保在这里开启
3. 执行 `View`
4. `View` 体执行结束未异常此时事务还未提交
5. 检查锁是否过时过时事务回滚不过时重新设置`key`延长`key`有效期已确保足够时间提交事务同时把`key`的状态改为`commiting`
6. 提交事务
7. 释放锁释放的时候会检查`doing``commiting`的值因为删除或者更改锁必须提供与当前锁的`value`相同的值确保不误删
[锁参考文章](http://doc.redisfans.com/string/set.html#id2)
"""
def decorator(fun):
@wraps(fun)
def wrapper(request, *args, **kwargs):
# `key`可能是组织相关的,如果是把组织`id`加上
_key = key.format(org_id=current_org.id)
doing_value = _generate_value(request)
commiting_value = _generate_value(request, stage=COMMITING)
try:
lock = acquire(_key, doing_value, timeout)
if not lock:
raise SomeoneIsDoingThis(detail=wait_msg)
with atomic(savepoint=False):
ret = fun(request, *args, **kwargs)
# 提交事务前,检查一下锁是否还在
# 锁在的话,更新锁的状态为 `commiting`,延长锁时间,确保事务提交
# 锁不在的话回滚
ok = change_lock_state_to_commiting(_key, doing_value, commiting_value)
if not ok:
# 超时或者被中断了
raise Timeout
return ret
finally:
# 释放锁,锁的两个值都要尝试,不确定异常是从什么位置抛出的
release(_key, commiting_value, doing_value)
return wrapper
return decorator

View File

@ -184,3 +184,8 @@ def org_aware_func(org_arg_name):
current_org = LocalProxy(get_current_org) current_org = LocalProxy(get_current_org)
def ensure_in_real_or_default_org():
if not current_org or current_org.is_root():
raise ValueError('You must in a real or default org!')

View File

@ -13,7 +13,7 @@ from applications.models import Application
from perms.utils.application.permission import ( from perms.utils.application.permission import (
get_application_system_users_id get_application_system_users_id
) )
from perms.api.asset.user_permission.mixin import ForAdminMixin, ForUserMixin from perms.api.asset.user_permission.mixin import RoleAdminMixin, RoleUserMixin
from common.permissions import IsOrgAdminOrAppUser from common.permissions import IsOrgAdminOrAppUser
from perms.hands import User, SystemUser from perms.hands import User, SystemUser
from perms import serializers from perms import serializers
@ -43,11 +43,11 @@ class GrantedApplicationSystemUsersMixin(ListAPIView):
return system_users return system_users
class UserGrantedApplicationSystemUsersApi(ForAdminMixin, GrantedApplicationSystemUsersMixin): class UserGrantedApplicationSystemUsersApi(RoleAdminMixin, GrantedApplicationSystemUsersMixin):
pass pass
class MyGrantedApplicationSystemUsersApi(ForUserMixin, GrantedApplicationSystemUsersMixin): class MyGrantedApplicationSystemUsersApi(RoleUserMixin, GrantedApplicationSystemUsersMixin):
pass pass

View File

@ -8,7 +8,7 @@ from applications.api.mixin import (
SerializeApplicationToTreeNodeMixin SerializeApplicationToTreeNodeMixin
) )
from perms import serializers from perms import serializers
from perms.api.asset.user_permission.mixin import ForAdminMixin, ForUserMixin from perms.api.asset.user_permission.mixin import RoleAdminMixin, RoleUserMixin
from perms.utils.application.user_permission import ( from perms.utils.application.user_permission import (
get_user_granted_all_applications get_user_granted_all_applications
) )
@ -34,11 +34,11 @@ class AllGrantedApplicationsMixin(CommonApiMixin, ListAPIView):
return queryset.only(*self.only_fields) return queryset.only(*self.only_fields)
class UserAllGrantedApplicationsApi(ForAdminMixin, AllGrantedApplicationsMixin): class UserAllGrantedApplicationsApi(RoleAdminMixin, AllGrantedApplicationsMixin):
pass pass
class MyAllGrantedApplicationsApi(ForUserMixin, AllGrantedApplicationsMixin): class MyAllGrantedApplicationsApi(RoleUserMixin, AllGrantedApplicationsMixin):
pass pass

View File

@ -4,37 +4,23 @@ from rest_framework.request import Request
from common.permissions import IsOrgAdminOrAppUser, IsValidUser from common.permissions import IsOrgAdminOrAppUser, IsValidUser
from common.utils import lazyproperty from common.utils import lazyproperty
from common.http import is_true
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org
from users.models import User from users.models import User
from perms.models import UserGrantedMappingNode from perms.utils.asset.user_permission import UserGrantedTreeRefreshController
class UserNodeGrantStatusDispatchMixin: class PermBaseMixin:
user: User
@staticmethod def get(self, request, *args, **kwargs):
def get_mapping_node_by_key(key, user): force = is_true(request.query_params.get('rebuild_tree'))
return UserGrantedMappingNode.objects.get(key=key, user=user) controller = UserGrantedTreeRefreshController(self.user)
controller.refresh_if_need(force)
def dispatch_get_data(self, key, user): return super().get(request, *args, **kwargs)
status = UserGrantedMappingNode.get_node_granted_status(key, user)
if status == UserGrantedMappingNode.GRANTED_DIRECT:
return self.get_data_on_node_direct_granted(key)
elif status == UserGrantedMappingNode.GRANTED_INDIRECT:
return self.get_data_on_node_indirect_granted(key)
else:
return self.get_data_on_node_not_granted(key)
def get_data_on_node_direct_granted(self, key):
raise NotImplementedError
def get_data_on_node_indirect_granted(self, key):
raise NotImplementedError
def get_data_on_node_not_granted(self, key):
raise NotImplementedError
class ForAdminMixin: class RoleAdminMixin(PermBaseMixin):
permission_classes = (IsOrgAdminOrAppUser,) permission_classes = (IsOrgAdminOrAppUser,)
kwargs: dict kwargs: dict
@ -44,7 +30,7 @@ class ForAdminMixin:
return User.objects.get(id=user_id) return User.objects.get(id=user_id)
class ForUserMixin: class RoleUserMixin(PermBaseMixin):
permission_classes = (IsValidUser,) permission_classes = (IsValidUser,)
request: Request request: Request

View File

@ -1,156 +0,0 @@
# -*- coding: utf-8 -*-
#
from perms.api.asset.user_permission.mixin import UserNodeGrantStatusDispatchMixin
from rest_framework.generics import ListAPIView
from rest_framework.response import Response
from rest_framework.request import Request
from django.conf import settings
from assets.api.mixin import SerializeToTreeNodeMixin
from common.utils import get_logger
from perms.pagination import GrantedAssetLimitOffsetPagination
from assets.models import Asset, Node, FavoriteAsset
from perms import serializers
from perms.utils.asset.user_permission import (
get_node_all_granted_assets, get_user_direct_granted_assets,
get_user_granted_all_assets
)
from .mixin import ForAdminMixin, ForUserMixin
logger = get_logger(__name__)
class UserDirectGrantedAssetsApi(ListAPIView):
"""
用户直接授权的资产的列表也就是授权规则上直接授权的资产并非是来自节点的
"""
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
filterset_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
user = self.user
assets = get_user_direct_granted_assets(user)\
.prefetch_related('platform')\
.only(*self.only_fields)
return assets
class UserFavoriteGrantedAssetsApi(ListAPIView):
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
filterset_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
user = self.user
assets = FavoriteAsset.get_user_favorite_assets(user)\
.prefetch_related('platform')\
.only(*self.only_fields)
return assets
class AssetsAsTreeMixin(SerializeToTreeNodeMixin):
"""
资产 序列化成树的结构返回
"""
def list(self, request: Request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
if request.query_params.get('search'):
# 如果用户搜索的条件不精准,会导致返回大量的无意义数据。
# 这里限制一下返回数据的最大条数
queryset = queryset[:999]
data = self.serialize_assets(queryset, None)
return Response(data=data)
class UserDirectGrantedAssetsForAdminApi(ForAdminMixin, UserDirectGrantedAssetsApi):
pass
class MyDirectGrantedAssetsApi(ForUserMixin, UserDirectGrantedAssetsApi):
pass
class UserFavoriteGrantedAssetsForAdminApi(ForAdminMixin, UserFavoriteGrantedAssetsApi):
pass
class MyFavoriteGrantedAssetsApi(ForUserMixin, UserFavoriteGrantedAssetsApi):
pass
class UserDirectGrantedAssetsAsTreeForAdminApi(ForAdminMixin, AssetsAsTreeMixin, UserDirectGrantedAssetsApi):
pass
class MyUngroupAssetsAsTreeApi(ForUserMixin, AssetsAsTreeMixin, UserDirectGrantedAssetsApi):
def get_queryset(self):
queryset = super().get_queryset()
if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
queryset = queryset.none()
return queryset
class UserAllGrantedAssetsApi(ForAdminMixin, ListAPIView):
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
serializer_class = serializers.AssetGrantedSerializer
filterset_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
queryset = get_user_granted_all_assets(self.user)
queryset = queryset.prefetch_related('platform')
return queryset.only(*self.only_fields)
class MyAllGrantedAssetsApi(ForUserMixin, UserAllGrantedAssetsApi):
pass
class MyAllAssetsAsTreeApi(ForUserMixin, AssetsAsTreeMixin, UserAllGrantedAssetsApi):
search_fields = ['hostname', 'ip']
class UserGrantedNodeAssetsApi(UserNodeGrantStatusDispatchMixin, ListAPIView):
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
filterset_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
pagination_class = GrantedAssetLimitOffsetPagination
pagination_node: Node
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
node_id = self.kwargs.get("node_id")
node = Node.objects.get(id=node_id)
self.pagination_node = node
return self.dispatch_get_data(node.key, self.user)
def get_data_on_node_direct_granted(self, key):
# 如果这个节点是直接授权的(或者说祖先节点直接授权的), 获取下面的所有资产
return Node.get_node_all_assets_by_key_v2(key)
def get_data_on_node_indirect_granted(self, key):
self.pagination_node = self.get_mapping_node_by_key(key, self.user)
return get_node_all_granted_assets(self.user, key)
def get_data_on_node_not_granted(self, key):
return Asset.objects.none()
class UserGrantedNodeAssetsForAdminApi(ForAdminMixin, UserGrantedNodeAssetsApi):
pass
class MyGrantedNodeAssetsApi(ForUserMixin, UserGrantedNodeAssetsApi):
pass

View File

@ -0,0 +1 @@
from .views import *

View File

@ -0,0 +1,127 @@
from rest_framework.response import Response
from rest_framework.request import Request
from users.models import User
from assets.api.mixin import SerializeToTreeNodeMixin
from common.utils import get_logger
from perms.pagination import NodeGrantedAssetPagination, AllGrantedAssetPagination
from assets.models import Asset, Node
from perms import serializers
from perms.utils.asset.user_permission import UserGrantedAssetsQueryUtils, QuerySetStage
logger = get_logger(__name__)
# 获取数据的 ------------------------------------------------------------
class UserDirectGrantedAssetsQuerysetMixin:
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
user: User
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
user = self.user
assets = UserGrantedAssetsQueryUtils(user) \
.get_direct_granted_assets() \
.prefetch_related('platform') \
.only(*self.only_fields)
return assets
class UserAllGrantedAssetsQuerysetMixin:
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
pagination_class = AllGrantedAssetPagination
user: User
def get_union_queryset(self, qs_stage: QuerySetStage):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
qs_stage.prefetch_related('platform').only(*self.only_fields)
queryset = UserGrantedAssetsQueryUtils(self.user) \
.get_all_granted_assets(qs_stage)
return queryset
class UserFavoriteGrantedAssetsMixin:
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
user: User
def get_union_queryset(self, qs_stage: QuerySetStage):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
user = self.user
qs_stage.prefetch_related('platform').only(*self.only_fields)
utils = UserGrantedAssetsQueryUtils(user)
assets = utils.get_favorite_assets(qs_stage=qs_stage)
return assets
class UserGrantedNodeAssetsMixin:
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
pagination_class = NodeGrantedAssetPagination
pagination_node: Node
user: User
def get_union_queryset(self, qs_stage: QuerySetStage):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
node_id = self.kwargs.get("node_id")
qs_stage.prefetch_related('platform').only(*self.only_fields)
node, assets = UserGrantedAssetsQueryUtils(self.user).get_node_all_assets(
node_id, qs_stage=qs_stage
)
self.pagination_node = node
return assets
# 控制格式的 ----------------------------------------------------
class AssetsUnionQuerysetMixin:
def get_queryset_union_prefer(self):
if hasattr(self, 'get_union_queryset'):
# 为了支持 union 查询
queryset = Asset.objects.all().distinct()
queryset = self.filter_queryset(queryset)
qs_stage = QuerySetStage()
qs_stage.and_with_queryset(queryset)
queryset = self.get_union_queryset(qs_stage)
else:
queryset = self.filter_queryset(self.get_queryset())
return queryset
class AssetsSerializerFormatMixin(AssetsUnionQuerysetMixin):
serializer_class = serializers.AssetGrantedSerializer
filterset_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def list(self, request, *args, **kwargs):
queryset = self.get_queryset_union_prefer()
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
class AssetsTreeFormatMixin(AssetsUnionQuerysetMixin, SerializeToTreeNodeMixin):
"""
资产 序列化成树的结构返回
"""
def list(self, request: Request, *args, **kwargs):
queryset = self.get_queryset_union_prefer()
if request.query_params.get('search'):
# 如果用户搜索的条件不精准,会导致返回大量的无意义数据。
# 这里限制一下返回数据的最大条数
queryset = queryset[:999]
data = self.serialize_assets(queryset, None)
return Response(data=data)
# def get_serializer_class(self):
# return EmptySerializer

View File

@ -0,0 +1,99 @@
from rest_framework.generics import ListAPIView
from django.conf import settings
from common.utils import get_logger
from ..mixin import RoleAdminMixin, RoleUserMixin
from .mixin import (
UserAllGrantedAssetsQuerysetMixin, UserDirectGrantedAssetsQuerysetMixin, UserFavoriteGrantedAssetsMixin,
UserGrantedNodeAssetsMixin, AssetsSerializerFormatMixin, AssetsTreeFormatMixin,
)
__all__ = [
'UserDirectGrantedAssetsForAdminApi', 'MyDirectGrantedAssetsApi', 'UserFavoriteGrantedAssetsForAdminApi',
'MyFavoriteGrantedAssetsApi', 'UserDirectGrantedAssetsAsTreeForAdminApi', 'MyUngroupAssetsAsTreeApi',
'UserAllGrantedAssetsApi', 'MyAllGrantedAssetsApi', 'MyAllAssetsAsTreeApi', 'UserGrantedNodeAssetsForAdminApi',
'MyGrantedNodeAssetsApi',
]
logger = get_logger(__name__)
class UserDirectGrantedAssetsForAdminApi(UserDirectGrantedAssetsQuerysetMixin,
RoleAdminMixin,
AssetsSerializerFormatMixin,
ListAPIView):
pass
class MyDirectGrantedAssetsApi(UserDirectGrantedAssetsQuerysetMixin,
RoleUserMixin,
AssetsSerializerFormatMixin,
ListAPIView):
pass
class UserFavoriteGrantedAssetsForAdminApi(UserFavoriteGrantedAssetsMixin,
RoleAdminMixin,
AssetsSerializerFormatMixin,
ListAPIView):
pass
class MyFavoriteGrantedAssetsApi(UserFavoriteGrantedAssetsMixin,
RoleUserMixin,
AssetsSerializerFormatMixin,
ListAPIView):
pass
class UserDirectGrantedAssetsAsTreeForAdminApi(UserDirectGrantedAssetsQuerysetMixin,
RoleAdminMixin,
AssetsTreeFormatMixin,
ListAPIView):
pass
class MyUngroupAssetsAsTreeApi(UserDirectGrantedAssetsQuerysetMixin,
RoleUserMixin,
AssetsTreeFormatMixin,
ListAPIView):
def get_queryset(self):
queryset = super().get_queryset()
if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
queryset = queryset.none()
return queryset
class UserAllGrantedAssetsApi(UserAllGrantedAssetsQuerysetMixin,
RoleAdminMixin,
AssetsSerializerFormatMixin,
ListAPIView):
pass
class MyAllGrantedAssetsApi(UserAllGrantedAssetsQuerysetMixin,
RoleUserMixin,
AssetsSerializerFormatMixin,
ListAPIView):
pass
class MyAllAssetsAsTreeApi(UserAllGrantedAssetsQuerysetMixin,
RoleUserMixin,
AssetsTreeFormatMixin,
ListAPIView):
search_fields = ['hostname', 'ip']
class UserGrantedNodeAssetsForAdminApi(UserGrantedNodeAssetsMixin,
RoleAdminMixin,
AssetsSerializerFormatMixin,
ListAPIView):
pass
class MyGrantedNodeAssetsApi(UserGrantedNodeAssetsMixin,
RoleUserMixin,
AssetsSerializerFormatMixin,
ListAPIView):
pass

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import abc import abc
from django.conf import settings
from rest_framework.generics import ( from rest_framework.generics import (
ListAPIView ListAPIView
) )
@ -10,16 +9,11 @@ from rest_framework.request import Request
from assets.api.mixin import SerializeToTreeNodeMixin from assets.api.mixin import SerializeToTreeNodeMixin
from common.utils import get_logger from common.utils import get_logger
from .mixin import ForAdminMixin, ForUserMixin, UserNodeGrantStatusDispatchMixin from .mixin import RoleAdminMixin, RoleUserMixin
from perms.hands import Node, User from perms.hands import User
from perms import serializers from perms import serializers
from perms.utils.asset.user_permission import (
get_indirect_granted_node_children, from perms.utils.asset.user_permission import UserGrantedNodesQueryUtils
get_user_granted_nodes_list_via_mapping_node,
get_top_level_granted_nodes,
rebuild_user_tree_if_need, get_favorite_node,
get_ungrouped_node
)
logger = get_logger(__name__) logger = get_logger(__name__)
@ -61,7 +55,6 @@ class BaseGrantedNodeApi(_GrantedNodeStructApi, metaclass=abc.ABCMeta):
serializer_class = serializers.NodeGrantedSerializer serializer_class = serializers.NodeGrantedSerializer
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
rebuild_user_tree_if_need(request, self.user)
nodes = self.get_nodes() nodes = self.get_nodes()
serializer = self.get_serializer(nodes, many=True) serializer = self.get_serializer(nodes, many=True)
return Response(serializer.data) return Response(serializer.data)
@ -73,7 +66,6 @@ class BaseNodeChildrenApi(NodeChildrenMixin, BaseGrantedNodeApi, metaclass=abc.A
class BaseGrantedNodeAsTreeApi(SerializeToTreeNodeMixin, _GrantedNodeStructApi, metaclass=abc.ABCMeta): class BaseGrantedNodeAsTreeApi(SerializeToTreeNodeMixin, _GrantedNodeStructApi, metaclass=abc.ABCMeta):
def list(self, request: Request, *args, **kwargs): def list(self, request: Request, *args, **kwargs):
rebuild_user_tree_if_need(request, self.user)
nodes = self.get_nodes() nodes = self.get_nodes()
nodes = self.serialize_nodes(nodes, with_asset_amount=True) nodes = self.serialize_nodes(nodes, with_asset_amount=True)
return Response(data=nodes) return Response(data=nodes)
@ -83,30 +75,16 @@ class BaseNodeChildrenAsTreeApi(NodeChildrenMixin, BaseGrantedNodeAsTreeApi, met
pass pass
class UserGrantedNodeChildrenMixin(UserNodeGrantStatusDispatchMixin): class UserGrantedNodeChildrenMixin:
user: User user: User
request: Request request: Request
def get_children(self): def get_children(self):
user = self.user user = self.user
key = self.request.query_params.get('key') key = self.request.query_params.get('key')
nodes = UserGrantedNodesQueryUtils(user).get_node_children(key)
if not key:
nodes = list(get_top_level_granted_nodes(user))
else:
nodes = self.dispatch_get_data(key, user)
return nodes return nodes
def get_data_on_node_direct_granted(self, key):
return Node.objects.filter(parent_key=key)
def get_data_on_node_indirect_granted(self, key):
nodes = get_indirect_granted_node_children(self.user, key)
return nodes
def get_data_on_node_not_granted(self, key):
return Node.objects.none()
class UserGrantedNodesMixin: class UserGrantedNodesMixin:
""" """
@ -115,41 +93,38 @@ class UserGrantedNodesMixin:
user: User user: User
def get_nodes(self): def get_nodes(self):
nodes = [] utils = UserGrantedNodesQueryUtils(self.user)
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: nodes = utils.get_whole_tree_nodes()
nodes.append(get_ungrouped_node(self.user))
nodes.append(get_favorite_node(self.user))
nodes.extend(get_user_granted_nodes_list_via_mapping_node(self.user))
return nodes return nodes
# ------------------------------------------ # ------------------------------------------
# 最终的 api # 最终的 api
class UserGrantedNodeChildrenForAdminApi(ForAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi): class UserGrantedNodeChildrenForAdminApi(RoleAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi):
pass pass
class MyGrantedNodeChildrenApi(ForUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi): class MyGrantedNodeChildrenApi(RoleUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi):
pass pass
class UserGrantedNodeChildrenAsTreeForAdminApi(ForAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi): class UserGrantedNodeChildrenAsTreeForAdminApi(RoleAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi):
pass pass
class MyGrantedNodeChildrenAsTreeApi(ForUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi): class MyGrantedNodeChildrenAsTreeApi(RoleUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi):
pass pass
class UserGrantedNodesForAdminApi(ForAdminMixin, UserGrantedNodesMixin, BaseGrantedNodeApi): class UserGrantedNodesForAdminApi(RoleAdminMixin, UserGrantedNodesMixin, BaseGrantedNodeApi):
pass pass
class MyGrantedNodesApi(ForUserMixin, UserGrantedNodesMixin, BaseGrantedNodeApi): class MyGrantedNodesApi(RoleUserMixin, UserGrantedNodesMixin, BaseGrantedNodeApi):
pass pass
class MyGrantedNodesAsTreeApi(ForUserMixin, UserGrantedNodesMixin, BaseGrantedNodeAsTreeApi): class MyGrantedNodesAsTreeApi(RoleUserMixin, UserGrantedNodesMixin, BaseGrantedNodeAsTreeApi):
pass pass
# ------------------------------------------ # ------------------------------------------

View File

@ -1,29 +1,23 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from itertools import chain
from rest_framework.generics import ListAPIView from rest_framework.generics import ListAPIView
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from django.db.models import F, Value, CharField, Q from django.db.models import F, Value, CharField
from django.conf import settings from django.conf import settings
from common.utils.common import timeit
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org
from common.permissions import IsValidUser from common.permissions import IsValidUser
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from .mixin import UserNodeGrantStatusDispatchMixin, ForUserMixin, ForAdminMixin from .mixin import RoleUserMixin, RoleAdminMixin
from perms.utils.asset.user_permission import ( from perms.utils.asset.user_permission import (
get_indirect_granted_node_children, UNGROUPED_NODE_KEY, FAVORITE_NODE_KEY, UserGrantedTreeBuildUtils, get_user_all_asset_perm_ids,
get_user_direct_granted_assets, get_top_level_granted_nodes, UserGrantedNodesQueryUtils, UserGrantedAssetsQueryUtils,
get_user_granted_nodes_list_via_mapping_node, QuerySetStage,
get_user_granted_all_assets, rebuild_user_tree_if_need,
get_user_all_assetpermissions_id, get_favorite_node,
get_ungrouped_node, compute_tmp_mapping_node_from_perm,
TMP_GRANTED_FIELD, count_direct_granted_node_assets,
count_node_all_granted_assets
) )
from perms.models import AssetPermission from perms.models import AssetPermission, PermNode
from assets.models import Asset, FavoriteAsset from assets.models import Asset
from assets.api import SerializeToTreeNodeMixin from assets.api import SerializeToTreeNodeMixin
from perms.hands import Node from perms.hands import Node
@ -33,76 +27,45 @@ logger = get_logger(__name__)
class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView): class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
permission_classes = (IsValidUser,) permission_classes = (IsValidUser,)
def add_ungrouped_resource(self, data: list, user, asset_perms_id): @timeit
def add_ungrouped_resource(self, data: list, nodes_query_utils, assets_query_utils):
if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return return
ungrouped_node = nodes_query_utils.get_ungrouped_node()
ungrouped_node = get_ungrouped_node(user, asset_perms_id=asset_perms_id) direct_granted_assets = assets_query_utils.get_direct_granted_assets().annotate(
direct_granted_assets = get_user_direct_granted_assets(
user, asset_perms_id=asset_perms_id
).annotate(
parent_key=Value(ungrouped_node.key, output_field=CharField()) parent_key=Value(ungrouped_node.key, output_field=CharField())
).prefetch_related('platform') ).prefetch_related('platform')
data.extend(self.serialize_nodes([ungrouped_node], with_asset_amount=True)) data.extend(self.serialize_nodes([ungrouped_node], with_asset_amount=True))
data.extend(self.serialize_assets(direct_granted_assets)) data.extend(self.serialize_assets(direct_granted_assets))
def add_favorite_resource(self, data: list, user, asset_perms_id): @timeit
favorite_node = get_favorite_node(user, asset_perms_id) def add_favorite_resource(self, data: list, nodes_query_utils, assets_query_utils):
favorite_assets = FavoriteAsset.get_user_favorite_assets( favorite_node = nodes_query_utils.get_favorite_node()
user, asset_perms_id=asset_perms_id
).annotate( qs_state = QuerySetStage().annotate(
parent_key=Value(favorite_node.key, output_field=CharField()) parent_key=Value(favorite_node.key, output_field=CharField())
).prefetch_related('platform') ).prefetch_related('platform')
favorite_assets = assets_query_utils.get_favorite_assets(qs_stage=qs_state, only=())
data.extend(self.serialize_nodes([favorite_node], with_asset_amount=True)) data.extend(self.serialize_nodes([favorite_node], with_asset_amount=True))
data.extend(self.serialize_assets(favorite_assets)) data.extend(self.serialize_assets(favorite_assets))
@timeit
def add_node_filtered_by_system_user(self, data: list, user, asset_perms_id): def add_node_filtered_by_system_user(self, data: list, user, asset_perms_id):
tmp_nodes = compute_tmp_mapping_node_from_perm(user, asset_perms_id=asset_perms_id) utils = UserGrantedTreeBuildUtils(user, asset_perms_id)
granted_nodes_key = [] nodes = utils.get_whole_tree_nodes()
for _node in tmp_nodes: data.extend(self.serialize_nodes(nodes, with_asset_amount=True))
_granted = getattr(_node, TMP_GRANTED_FIELD, False)
if not _granted:
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
assets_amount = count_direct_granted_node_assets(user, _node.key, asset_perms_id)
else:
assets_amount = count_node_all_granted_assets(user, _node.key, asset_perms_id)
_node.assets_amount = assets_amount
else:
granted_nodes_key.append(_node.key)
# 查询他们的子节点 def add_assets(self, data: list, assets_query_utils: UserGrantedAssetsQueryUtils):
q = Q() qs_stage = QuerySetStage().annotate(parent_key=F('nodes__key')).prefetch_related('platform')
for _key in granted_nodes_key:
q |= Q(key__startswith=f'{_key}:')
if q:
descendant_nodes = Node.objects.filter(q).distinct()
else:
descendant_nodes = Node.objects.none()
data.extend(self.serialize_nodes(chain(tmp_nodes, descendant_nodes), with_asset_amount=True))
def add_assets(self, data: list, user, asset_perms_id):
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
all_assets = get_user_granted_all_assets( all_assets = assets_query_utils.get_direct_granted_nodes_assets(qs_stage=qs_stage)
user,
via_mapping_node=False,
include_direct_granted_assets=False,
asset_perms_id=asset_perms_id
)
else: else:
all_assets = get_user_granted_all_assets( all_assets = assets_query_utils.get_all_granted_assets(qs_stage=qs_stage)
user,
via_mapping_node=False,
include_direct_granted_assets=True,
asset_perms_id=asset_perms_id
)
all_assets = all_assets.annotate(
parent_key=F('nodes__key')
).prefetch_related('platform')
data.extend(self.serialize_assets(all_assets)) data.extend(self.serialize_assets(all_assets))
@tmp_to_root_org() @tmp_to_root_org()
@ -117,7 +80,7 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
user = request.user user = request.user
data = [] data = []
asset_perms_id = get_user_all_assetpermissions_id(user) asset_perms_id = get_user_all_asset_perm_ids(user)
system_user_id = request.query_params.get('system_user') system_user_id = request.query_params.get('system_user')
if system_user_id: if system_user_id:
@ -125,89 +88,72 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
id__in=asset_perms_id, system_users__id=system_user_id, actions__gt=0 id__in=asset_perms_id, system_users__id=system_user_id, actions__gt=0
).values_list('id', flat=True).distinct()) ).values_list('id', flat=True).distinct())
self.add_ungrouped_resource(data, user, asset_perms_id) nodes_query_utils = UserGrantedNodesQueryUtils(user, asset_perms_id)
self.add_favorite_resource(data, user, asset_perms_id) assets_query_utils = UserGrantedAssetsQueryUtils(user, asset_perms_id)
self.add_ungrouped_resource(data, nodes_query_utils, assets_query_utils)
self.add_favorite_resource(data, nodes_query_utils, assets_query_utils)
if system_user_id: if system_user_id:
# 有系统用户筛选的需要重新计算树结构
self.add_node_filtered_by_system_user(data, user, asset_perms_id) self.add_node_filtered_by_system_user(data, user, asset_perms_id)
else: else:
rebuild_user_tree_if_need(request, user) all_nodes = nodes_query_utils.get_whole_tree_nodes(with_special=False)
all_nodes = get_user_granted_nodes_list_via_mapping_node(user)
data.extend(self.serialize_nodes(all_nodes, with_asset_amount=True)) data.extend(self.serialize_nodes(all_nodes, with_asset_amount=True))
self.add_assets(data, user, asset_perms_id) self.add_assets(data, assets_query_utils)
return Response(data=data) return Response(data=data)
class GrantedNodeChildrenWithAssetsAsTreeApiMixin(UserNodeGrantStatusDispatchMixin, class GrantedNodeChildrenWithAssetsAsTreeApiMixin(SerializeToTreeNodeMixin,
SerializeToTreeNodeMixin,
ListAPIView): ListAPIView):
""" """
带资产的授权树 带资产的授权树
""" """
user: None user: None
def get_data_on_node_direct_granted(self, key): def ensure_key(self):
nodes = Node.objects.filter(parent_key=key) key = self.request.query_params.get('key', None)
assets = Asset.org_objects.filter(nodes__key=key).distinct() id = self.request.query_params.get('id', None)
assets = assets.prefetch_related('platform')
return nodes, assets
def get_data_on_node_indirect_granted(self, key): if key is not None:
user = self.user return key
asset_perms_id = get_user_all_assetpermissions_id(user)
nodes = get_indirect_granted_node_children(user, key) node = get_object_or_none(Node, id=id)
if node:
assets = Asset.org_objects.filter( return node.key
nodes__key=key,
).filter(
granted_by_permissions__id__in=asset_perms_id
).distinct()
assets = assets.prefetch_related('platform')
return nodes, assets
def get_data_on_node_not_granted(self, key):
return Node.objects.none(), Asset.objects.none()
def get_data(self, key, user):
assets, nodes = [], []
if not key:
root_nodes = get_top_level_granted_nodes(user)
nodes.extend(root_nodes)
elif key == UNGROUPED_NODE_KEY:
assets = get_user_direct_granted_assets(user)
assets = assets.prefetch_related('platform')
elif key == FAVORITE_NODE_KEY:
assets = FavoriteAsset.get_user_favorite_assets(user)
else:
nodes, assets = self.dispatch_get_data(key, user)
return nodes, assets
def id2key_if_have(self):
id = self.request.query_params.get('id')
if id is not None:
node = get_object_or_none(Node, id=id)
if node:
return node.key
def list(self, request: Request, *args, **kwargs): def list(self, request: Request, *args, **kwargs):
key = self.request.query_params.get('key') user = self.user
if key is None: key = self.ensure_key()
key = self.id2key_if_have()
nodes_query_utils = UserGrantedNodesQueryUtils(user)
assets_query_utils = UserGrantedAssetsQueryUtils(user)
nodes = PermNode.objects.none()
assets = Asset.objects.none()
if not key:
nodes = nodes_query_utils.get_top_level_nodes()
elif key == PermNode.UNGROUPED_NODE_KEY:
assets = assets_query_utils.get_ungroup_assets()
elif key == PermNode.FAVORITE_NODE_KEY:
assets = assets_query_utils.get_favorite_assets()
else:
nodes = nodes_query_utils.get_node_children(key)
assets = assets_query_utils.get_node_assets(key)
assets = assets.prefetch_related('platform')
user = self.user user = self.user
rebuild_user_tree_if_need(request, user)
nodes, assets = self.get_data(key, user)
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True) tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, key) tree_assets = self.serialize_assets(assets, key)
return Response(data=[*tree_nodes, *tree_assets]) return Response(data=[*tree_nodes, *tree_assets])
class UserGrantedNodeChildrenWithAssetsAsTreeApi(ForAdminMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin): class UserGrantedNodeChildrenWithAssetsAsTreeApi(RoleAdminMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin):
pass pass
class MyGrantedNodeChildrenWithAssetsAsTreeApi(ForUserMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin): class MyGrantedNodeChildrenWithAssetsAsTreeApi(RoleUserMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin):
pass pass

View File

@ -1,10 +1,10 @@
from rest_framework import generics from rest_framework import generics
from django.db.models import Q
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from assets.models import SystemUser from assets.models import SystemUser
from common.permissions import IsValidUser from common.permissions import IsValidUser
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org
from perms.utils.asset.user_permission import get_user_all_asset_perm_ids
from .. import serializers from .. import serializers
@ -16,9 +16,9 @@ class SystemUserPermission(generics.ListAPIView):
def get_queryset(self): def get_queryset(self):
user = self.request.user user = self.request.user
asset_perms_id = get_user_all_asset_perm_ids(user)
queryset = SystemUser.objects.filter( queryset = SystemUser.objects.filter(
Q(granted_by_permissions__users=user) | granted_by_permissions__id__in=asset_perms_id
Q(granted_by_permissions__user_groups__users=user)
).distinct() ).distinct()
return queryset return queryset

View File

@ -1,47 +0,0 @@
from django.utils.crypto import get_random_string
from perms.utils import rebuild_user_mapping_nodes_if_need_with_lock
from common.thread_pools import SingletonThreadPoolExecutor
from common.utils import get_logger
from perms.models import RebuildUserTreeTask
logger = get_logger(__name__)
class Executor(SingletonThreadPoolExecutor):
pass
executor = Executor()
def run_mapping_node_tasks():
failed_user_ids = []
ident = get_random_string()
logger.debug(f'[{ident}]mapping_node_tasks running')
while True:
task = RebuildUserTreeTask.objects.exclude(
user_id__in=failed_user_ids
).first()
if task is None:
break
user = task.user
try:
rebuild_user_mapping_nodes_if_need_with_lock(user)
except:
logger.exception(f'[{ident}]mapping_node_tasks_exception')
failed_user_ids.append(user.id)
logger.debug(f'[{ident}]mapping_node_tasks finished')
def submit_update_mapping_node_task():
executor.submit(run_mapping_node_tasks)
def submit_update_mapping_node_task_for_user(user):
executor.submit(rebuild_user_mapping_nodes_if_need_with_lock, user)

11
apps/perms/locks.py Normal file
View File

@ -0,0 +1,11 @@
from common.utils.lock import DistributedLock
class UserGrantedTreeRebuildLock(DistributedLock):
name_template = 'perms.user.asset.node.tree.rebuid.<org_id:{org_id}>.<user_id:{user_id}>'
def __init__(self, org_id, user_id):
name = self.name_template.format(
org_id=org_id, user_id=user_id
)
super().__init__(name=name)

View File

@ -1,19 +1,6 @@
# Generated by Django 2.2.13 on 2020-08-21 08:20 # Generated by Django 2.2.13 on 2020-08-21 08:20
from django.db import migrations from django.db import migrations
from perms.tasks import dispatch_mapping_node_tasks
def start_build_users_perm_tree_task(apps, schema_editor):
User = apps.get_model('users', 'User')
RebuildUserTreeTask = apps.get_model('perms', 'RebuildUserTreeTask')
user_ids = User.objects.all().values_list('id', flat=True).distinct()
RebuildUserTreeTask.objects.bulk_create(
[RebuildUserTreeTask(user_id=i) for i in user_ids]
)
dispatch_mapping_node_tasks.delay()
class Migration(migrations.Migration): class Migration(migrations.Migration):
@ -23,5 +10,4 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.RunPython(start_build_users_perm_tree_task)
] ]

View File

@ -0,0 +1,65 @@
# Generated by Django 3.1 on 2021-02-04 09:49
import assets.models.node
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('assets', '0066_remove_node_assets_amount'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('perms', '0017_auto_20210104_0435'),
]
operations = [
migrations.CreateModel(
name='UserAssetGrantedTreeNodeRelation',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('org_id', models.CharField(blank=True, db_index=True, default='', max_length=36, verbose_name='Organization')),
('node_key', models.CharField(db_index=True, max_length=64, verbose_name='Key')),
('node_parent_key', models.CharField(db_index=True, default='', max_length=64, verbose_name='Parent key')),
('node_from', models.CharField(choices=[('granted', 'Direct node granted'), ('child', 'Have children node'), ('asset', 'Direct asset granted')], db_index=True, max_length=16)),
('node_assets_amount', models.IntegerField(default=0)),
('node', models.ForeignKey(db_constraint=False, default=None, on_delete=django.db.models.deletion.CASCADE, related_name='granted_node_rels', to='assets.node')),
('user', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
'abstract': False,
},
bases=(assets.models.node.FamilyMixin, models.Model),
),
migrations.RemoveField(
model_name='usergrantedmappingnode',
name='node',
),
migrations.RemoveField(
model_name='usergrantedmappingnode',
name='user',
),
migrations.CreateModel(
name='PermNode',
fields=[
],
options={
'ordering': [],
'proxy': True,
'indexes': [],
'constraints': [],
},
bases=('assets.node',),
),
migrations.DeleteModel(
name='RebuildUserTreeTask',
),
migrations.DeleteModel(
name='UserGrantedMappingNode',
),
]

View File

@ -2,7 +2,10 @@ import logging
from functools import reduce from functools import reduce
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.db.models import F
from common.db.models import ChoiceSet
from orgs.mixins.models import OrgModelMixin
from common.db import models from common.db import models
from common.utils import lazyproperty from common.utils import lazyproperty
from assets.models import Asset, SystemUser, Node, FamilyMixin from assets.models import Asset, SystemUser, Node, FamilyMixin
@ -11,7 +14,7 @@ from .base import BasePermission
__all__ = [ __all__ = [
'AssetPermission', 'Action', 'UserGrantedMappingNode', 'RebuildUserTreeTask', 'AssetPermission', 'Action', 'PermNode', 'UserAssetGrantedTreeNodeRelation',
] ]
# 使用场景 # 使用场景
@ -135,39 +138,109 @@ class AssetPermission(BasePermission):
from assets.models import Node from assets.models import Node
nodes_keys = self.nodes.all().values_list('key', flat=True) nodes_keys = self.nodes.all().values_list('key', flat=True)
assets_ids = set(self.assets.all().values_list('id', flat=True)) assets_ids = set(self.assets.all().values_list('id', flat=True))
nodes_assets_ids = Node.get_nodes_all_assets_ids(nodes_keys) nodes_assets_ids = Node.get_nodes_all_assets_ids_by_keys(nodes_keys)
assets_ids.update(nodes_assets_ids) assets_ids.update(nodes_assets_ids)
assets = Asset.objects.filter(id__in=assets_ids) assets = Asset.objects.filter(id__in=assets_ids)
return assets return assets
class UserAssetGrantedTreeNodeRelation(OrgModelMixin, FamilyMixin, models.JMSBaseModel):
class NodeFrom(ChoiceSet):
granted = 'granted', 'Direct node granted'
child = 'child', 'Have children node'
asset = 'asset', 'Direct asset granted'
class UserGrantedMappingNode(FamilyMixin, models.JMSBaseModel):
node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE,
db_constraint=False, null=True, related_name='mapping_nodes')
key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True) # '1:1:1:1'
user = models.ForeignKey('users.User', db_constraint=False, on_delete=models.CASCADE) user = models.ForeignKey('users.User', db_constraint=False, on_delete=models.CASCADE)
granted = models.BooleanField(default=False, db_index=True) node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE,
asset_granted = models.BooleanField(default=False, db_index=True) db_constraint=False, null=False, related_name='granted_node_rels')
parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'), db_index=True) # '1:1:1:1' node_key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True)
assets_amount = models.IntegerField(default=0) node_parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'), db_index=True)
node_from = models.CharField(choices=NodeFrom.choices, max_length=16, db_index=True)
node_assets_amount = models.IntegerField(default=0)
GRANTED_DIRECT = 1 @property
GRANTED_INDIRECT = 2 def key(self):
GRANTED_NONE = 0 return self.node_key
@property
def parent_key(self):
return self.node_parent_key
@classmethod @classmethod
def get_node_granted_status(cls, key, user): def get_node_granted_status(cls, user, key):
ancestor_keys = Node.get_node_ancestor_keys(key, with_self=True) ancestor_keys = set(cls.get_node_ancestor_keys(key, with_self=True))
has_granted = UserGrantedMappingNode.objects.filter( ancestor_rel_nodes = cls.objects.filter(user=user, node_key__in=ancestor_keys)
key__in=ancestor_keys, user=user
).values_list('granted', flat=True) for rel_node in ancestor_rel_nodes:
if not has_granted: if rel_node.key == key:
return cls.GRANTED_NONE return rel_node.node_from, rel_node
if any(list(has_granted)): if rel_node.node_from == cls.NodeFrom.granted:
return cls.GRANTED_DIRECT return cls.NodeFrom.granted, None
return cls.GRANTED_INDIRECT return '', None
class RebuildUserTreeTask(models.JMSBaseModel): class PermNode(Node):
user = models.ForeignKey('users.User', on_delete=models.CASCADE, verbose_name=_('User')) class Meta:
proxy = True
ordering = []
# 特殊节点
UNGROUPED_NODE_KEY = 'ungrouped'
UNGROUPED_NODE_VALUE = _('Ungrouped')
FAVORITE_NODE_KEY = 'favorite'
FAVORITE_NODE_VALUE = _('Favorite')
node_from = ''
granted_assets_amount = 0
# 提供可以设置 资产数量的字段
_assets_amount = None
annotate_granted_node_rel_fields = {
'granted_assets_amount': F('granted_node_rels__node_assets_amount'),
'node_from': F('granted_node_rels__node_from')
}
@property
def assets_amount(self):
_assets_amount = getattr(self, '_assets_amount')
if isinstance(_assets_amount, int):
return _assets_amount
return super().assets_amount
@assets_amount.setter
def assets_amount(self, value):
self._assets_amount = value
def use_granted_assets_amount(self):
self.assets_amount = self.granted_assets_amount
@classmethod
def get_ungrouped_node(cls, assets_amount):
return cls(
id=cls.UNGROUPED_NODE_KEY,
key=cls.UNGROUPED_NODE_KEY,
value=cls.UNGROUPED_NODE_VALUE,
assets_amount=assets_amount
)
@classmethod
def get_favorite_node(cls, assets_amount):
node = cls(
id=cls.FAVORITE_NODE_KEY,
key=cls.FAVORITE_NODE_KEY,
value=cls.FAVORITE_NODE_VALUE,
)
node.assets_amount = assets_amount
return node
def get_granted_status(self, user):
status, rel_node = UserAssetGrantedTreeNodeRelation.get_node_granted_status(user, self.key)
self.node_from = status
if rel_node:
self.granted_assets_amount = rel_node.node_assets_amount
return status
def save(self):
# 这是个只读 Model
raise NotImplementedError

View File

@ -1,30 +1,54 @@
from rest_framework.pagination import LimitOffsetPagination from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request from rest_framework.request import Request
from django.db.models import Sum
from perms.models import UserAssetGrantedTreeNodeRelation
from common.utils import get_logger from common.utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
class GrantedAssetLimitOffsetPagination(LimitOffsetPagination): class GrantedAssetPaginationBase(LimitOffsetPagination):
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
self._user = request.user
return super().paginate_queryset(queryset, request, view=None)
def get_count(self, queryset): def get_count(self, queryset):
exclude_query_params = { exclude_query_params = {
self.limit_query_param, self.limit_query_param,
self.offset_query_param, self.offset_query_param,
'key', 'all', 'show_current_asset', 'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw' 'cache_policy', 'display', 'draw',
'order',
} }
for k, v in self._request.query_params.items(): for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None: if k not in exclude_query_params and v is not None:
logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
return super().get_count(queryset) return super().get_count(queryset)
return self.get_count_from_nodes(queryset)
def get_count_from_nodes(self, queryset):
raise NotImplementedError
class NodeGrantedAssetPagination(GrantedAssetPaginationBase):
def get_count_from_nodes(self, queryset):
node = getattr(self._view, 'pagination_node', None) node = getattr(self._view, 'pagination_node', None)
if node: if node:
logger.debug(f'{self._request.get_full_path()} hit node.assets_amount[{node.assets_amount}]') logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}')
return node.assets_amount return node.assets_amount
else: else:
logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} not has `pagination_node` -> {self._request.get_full_path()}')
return super().get_count(queryset) return super().get_count(queryset)
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request class AllGrantedAssetPagination(GrantedAssetPaginationBase):
self._view = view def get_count_from_nodes(self, queryset):
return super().paginate_queryset(queryset, request, view=None) assets_amount = sum(UserAssetGrantedTreeNodeRelation.objects.filter(
user=self._user, node_parent_key=''
).values_list('node_assets_amount', flat=True))
logger.debug(f'Hit all assets amount {assets_amount} -> {self._request.get_full_path()}')
return assets_amount

View File

@ -0,0 +1,2 @@
from . import common
from . import refresh_perms

View File

@ -1,31 +1,22 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models.signals import m2m_changed, pre_delete, pre_save from django.db.models.signals import m2m_changed
from django.dispatch import receiver from django.dispatch import receiver
from perms.tasks import create_rebuild_user_tree_task, \
create_rebuild_user_tree_task_by_related_nodes_or_assets
from users.models import User, UserGroup from users.models import User, UserGroup
from assets.models import Asset, SystemUser from assets.models import SystemUser
from applications.models import Application from applications.models import Application
from common.utils import get_logger from common.utils import get_logger
from common.exceptions import M2MReverseNotAllowed from common.exceptions import M2MReverseNotAllowed
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR from common.const.signals import POST_ADD
from .models import AssetPermission, ApplicationPermission from perms.models import AssetPermission, ApplicationPermission
logger = get_logger(__file__) logger = get_logger(__file__)
def handle_rebuild_user_tree(instance, action, reverse, pk_set, **kwargs): @receiver(m2m_changed, sender=User.groups.through)
if action.startswith('post'): def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs):
if reverse:
create_rebuild_user_tree_task(pk_set)
else:
create_rebuild_user_tree_task([instance.id])
def handle_bind_groups_systemuser(instance, action, reverse, pk_set, **kwargs):
""" """
UserGroup 增加 User 增加的 User 需要与 UserGroup 关联的动态系统用户相关联 UserGroup 增加 User 增加的 User 需要与 UserGroup 关联的动态系统用户相关联
""" """
@ -47,53 +38,11 @@ def handle_bind_groups_systemuser(instance, action, reverse, pk_set, **kwargs):
system_user.users.add(*users_id) system_user.users.add(*users_id)
@receiver(m2m_changed, sender=User.groups.through)
def on_user_groups_change(**kwargs):
handle_rebuild_user_tree(**kwargs)
handle_bind_groups_systemuser(**kwargs)
@receiver([pre_save], sender=AssetPermission)
def on_asset_perm_deactive(instance: AssetPermission, **kwargs):
try:
old = AssetPermission.objects.only('is_active').get(id=instance.id)
if instance.is_active != old.is_active:
create_rebuild_user_tree_task_by_asset_perm(instance)
except AssetPermission.DoesNotExist:
pass
@receiver([pre_delete], sender=AssetPermission)
def on_asset_permission_delete(instance, **kwargs):
# 授权删除之前,查出所有相关用户
create_rebuild_user_tree_task_by_asset_perm(instance)
def create_rebuild_user_tree_task_by_asset_perm(asset_perm: AssetPermission):
user_ids = set()
user_ids.update(
UserGroup.objects.filter(
assetpermissions=asset_perm, users__id__isnull=False
).distinct().values_list('users__id', flat=True)
)
user_ids.update(
User.objects.filter(assetpermissions=asset_perm).distinct().values_list('id', flat=True)
)
create_rebuild_user_tree_task(user_ids)
def need_rebuild_mapping_node(action):
return action in (POST_REMOVE, POST_ADD, POST_CLEAR)
@receiver(m2m_changed, sender=AssetPermission.nodes.through) @receiver(m2m_changed, sender=AssetPermission.nodes.through)
def on_permission_nodes_changed(instance, action, reverse, pk_set, model, **kwargs): def on_permission_nodes_changed(instance, action, reverse, pk_set, model, **kwargs):
if reverse: if reverse:
raise M2MReverseNotAllowed raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
create_rebuild_user_tree_task_by_asset_perm(instance)
if action != POST_ADD: if action != POST_ADD:
return return
logger.debug("Asset permission nodes change signal received") logger.debug("Asset permission nodes change signal received")
@ -110,9 +59,6 @@ def on_permission_assets_changed(instance, action, reverse, pk_set, model, **kwa
if reverse: if reverse:
raise M2MReverseNotAllowed raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
create_rebuild_user_tree_task_by_asset_perm(instance)
if action != POST_ADD: if action != POST_ADD:
return return
logger.debug("Asset permission assets change signal received") logger.debug("Asset permission assets change signal received")
@ -150,9 +96,6 @@ def on_asset_permission_users_changed(instance, action, reverse, pk_set, model,
if reverse: if reverse:
raise M2MReverseNotAllowed raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
create_rebuild_user_tree_task(pk_set)
if action != POST_ADD: if action != POST_ADD:
return return
logger.debug("Asset permission users change signal received") logger.debug("Asset permission users change signal received")
@ -171,10 +114,6 @@ def on_asset_permission_user_groups_changed(instance, action, pk_set, model,
if reverse: if reverse:
raise M2MReverseNotAllowed raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
user_ids = User.objects.filter(groups__id__in=pk_set).distinct().values_list('id', flat=True)
create_rebuild_user_tree_task(user_ids)
if action != POST_ADD: if action != POST_ADD:
return return
logger.debug("Asset permission user groups change signal received") logger.debug("Asset permission user groups change signal received")
@ -187,21 +126,6 @@ def on_asset_permission_user_groups_changed(instance, action, pk_set, model,
system_user.groups.add(*tuple(groups)) system_user.groups.add(*tuple(groups))
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(action, instance, reverse, pk_set, **kwargs):
if not need_rebuild_mapping_node(action):
return
if reverse:
asset_pk_set = pk_set
node_pk_set = [instance.id]
else:
asset_pk_set = [instance.id]
node_pk_set = pk_set
create_rebuild_user_tree_task_by_related_nodes_or_assets.delay(node_pk_set, asset_pk_set)
@receiver(m2m_changed, sender=ApplicationPermission.system_users.through) @receiver(m2m_changed, sender=ApplicationPermission.system_users.through)
def on_application_permission_system_users_changed(sender, instance: ApplicationPermission, action, reverse, pk_set, **kwargs): def on_application_permission_system_users_changed(sender, instance: ApplicationPermission, action, reverse, pk_set, **kwargs):
if not instance.category_remote_app: if not instance.category_remote_app:

View File

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
#
from django.db.models.signals import m2m_changed, pre_delete, pre_save, post_save
from django.dispatch import receiver
from users.models import User
from assets.models import Asset
from orgs.utils import current_org
from common.utils import get_logger
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR
from perms.models import AssetPermission
from perms.utils.asset.user_permission import UserGrantedTreeRefreshController
logger = get_logger(__file__)
@receiver(m2m_changed, sender=User.groups.through)
def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs):
if action.startswith('post'):
if reverse:
group_ids = [instance.id]
user_ids = pk_set
else:
group_ids = pk_set
user_ids = [instance.id]
exists = AssetPermission.user_groups.through.objects.filter(usergroup_id__in=group_ids).exists()
if exists:
org_ids = [current_org.id]
UserGrantedTreeRefreshController.add_need_refresh_orgs_for_users(org_ids, user_ids)
@receiver([pre_delete], sender=AssetPermission)
def on_asset_perm_pre_delete(sender, instance, **kwargs):
# 授权删除之前,查出所有相关用户
UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id])
@receiver([pre_save], sender=AssetPermission)
def on_asset_perm_pre_save(sender, instance, **kwargs):
try:
old = AssetPermission.objects.get(id=instance.id)
if old.is_valid != instance.is_valid:
UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id])
except AssetPermission.DoesNotExist:
pass
@receiver([post_save], sender=AssetPermission)
def on_asset_perm_post_save(sender, instance, created, **kwargs):
if created:
UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id])
def need_rebuild_mapping_node(action):
return action in (POST_REMOVE, POST_ADD, POST_CLEAR)
@receiver(m2m_changed, sender=AssetPermission.nodes.through)
def on_permission_nodes_changed(sender, instance, action, reverse, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id])
@receiver(m2m_changed, sender=AssetPermission.assets.through)
def on_permission_assets_changed(sender, instance, action, reverse, pk_set, model, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids([instance.id])
@receiver(m2m_changed, sender=AssetPermission.users.through)
def on_asset_permission_users_changed(sender, action, reverse, pk_set, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
UserGrantedTreeRefreshController.add_need_refresh_orgs_for_users(
[current_org.id], pk_set
)
@receiver(m2m_changed, sender=AssetPermission.user_groups.through)
def on_asset_permission_user_groups_changed(sender, action, pk_set, reverse, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
user_ids = User.groups.through.objects.filter(usergroup_id__in=pk_set).distinct().values_list('user_id', flat=True)
UserGrantedTreeRefreshController.add_need_refresh_orgs_for_users(
[current_org.id], user_ids
)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(action, instance, reverse, pk_set, **kwargs):
if not need_rebuild_mapping_node(action):
return
if reverse:
asset_pk_set = pk_set
node_pk_set = [instance.id]
else:
asset_pk_set = [instance.id]
node_pk_set = pk_set
UserGrantedTreeRefreshController.add_need_refresh_on_nodes_assets_relate_change(node_pk_set, asset_pk_set)

View File

@ -2,39 +2,18 @@
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
from datetime import timedelta from datetime import timedelta
from django.db import transaction
from django.db.models import Q
from django.db.transaction import atomic from django.db.transaction import atomic
from django.conf import settings from django.conf import settings
from celery import shared_task from celery import shared_task
from common.utils import get_logger from common.utils import get_logger
from common.utils.timezone import now, dt_formater, dt_parser from common.utils.timezone import now, dt_formater, dt_parser
from users.models import User
from ops.celery.decorator import register_as_period_task from ops.celery.decorator import register_as_period_task
from assets.models import Node from perms.models import AssetPermission
from perms.models import RebuildUserTreeTask, AssetPermission from perms.utils.asset.user_permission import UserGrantedTreeRefreshController
from perms.utils.asset.user_permission import rebuild_user_mapping_nodes_if_need_with_lock, lock
logger = get_logger(__file__) logger = get_logger(__file__)
@shared_task(queue='node_tree')
def rebuild_user_mapping_nodes_celery_task(user_id):
user = User.objects.get(id=user_id)
try:
rebuild_user_mapping_nodes_if_need_with_lock(user)
except lock.SomeoneIsDoingThis:
pass
@shared_task(queue='node_tree')
def dispatch_mapping_node_tasks():
user_ids = RebuildUserTreeTask.objects.all().values_list('user_id', flat=True).distinct()
logger.info(f'>>> dispatch_mapping_node_tasks for users {list(user_ids)}')
for id in user_ids:
rebuild_user_mapping_nodes_celery_task.delay(id)
@register_as_period_task(interval=settings.PERM_EXPIRED_CHECK_PERIODIC) @register_as_period_task(interval=settings.PERM_EXPIRED_CHECK_PERIODIC)
@shared_task(queue='celery_check_asset_perm_expired') @shared_task(queue='celery_check_asset_perm_expired')
@atomic() @atomic()
@ -60,66 +39,9 @@ def check_asset_permission_expired():
setting.value = dt_formater(end) setting.value = dt_formater(end)
setting.save() setting.save()
ids = AssetPermission.objects.filter( asset_perm_ids = AssetPermission.objects.filter(
date_expired__gte=start, date_expired__lte=end date_expired__gte=start, date_expired__lte=end
).distinct().values_list('id', flat=True) ).distinct().values_list('id', flat=True)
logger.info(f'>>> checking {start} to {end} have {ids} expired') asset_perm_ids = list(asset_perm_ids)
dispatch_process_expired_asset_permission.delay(list(ids)) logger.info(f'>>> checking {start} to {end} have {asset_perm_ids} expired')
UserGrantedTreeRefreshController.add_need_refresh_by_asset_perm_ids_cross_orgs(asset_perm_ids)
@shared_task(queue='node_tree')
def dispatch_process_expired_asset_permission(asset_perms_id):
user_ids = User.objects.filter(
Q(assetpermissions__id__in=asset_perms_id) |
Q(groups__assetpermissions__id__in=asset_perms_id)
).distinct().values_list('id', flat=True)
RebuildUserTreeTask.objects.bulk_create(
[RebuildUserTreeTask(user_id=user_id) for user_id in user_ids]
)
dispatch_mapping_node_tasks.delay()
def create_rebuild_user_tree_task(user_ids):
RebuildUserTreeTask.objects.bulk_create(
[RebuildUserTreeTask(user_id=i) for i in user_ids]
)
transaction.on_commit(dispatch_mapping_node_tasks.delay)
@shared_task(queue='node_tree')
def create_rebuild_user_tree_task_by_related_nodes_or_assets(node_ids, asset_ids):
node_ids = set(node_ids)
node_keys = set()
nodes = Node.objects.filter(id__in=node_ids)
for _node in nodes:
node_keys.update(_node.get_ancestor_keys())
node_ids.update(
Node.objects.filter(key__in=node_keys).values_list('id', flat=True)
)
asset_perms_id = set()
asset_perms_id.update(
AssetPermission.objects.filter(
assets__id__in=asset_ids
).values_list('id', flat=True).distinct()
)
asset_perms_id.update(
AssetPermission.objects.filter(
nodes__id__in=node_ids
).values_list('id', flat=True).distinct()
)
user_ids = set()
user_ids.update(
User.objects.filter(
assetpermissions__id__in=asset_perms_id
).distinct().values_list('id', flat=True)
)
user_ids.update(
User.objects.filter(
groups__assetpermissions__id__in=asset_perms_id
).distinct().values_list('id', flat=True)
)
create_rebuild_user_tree_task(user_ids)

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,6 @@ import forgery_py
from .base import FakeDataGenerator from .base import FakeDataGenerator
from assets.models import * from assets.models import *
from assets.utils import check_node_assets_amount
class AdminUsersGenerator(FakeDataGenerator): class AdminUsersGenerator(FakeDataGenerator):
@ -93,4 +92,4 @@ class AssetsGenerator(FakeDataGenerator):
self.set_assets_nodes(creates) self.set_assets_nodes(creates)
def after_generate(self): def after_generate(self):
check_node_assets_amount() pass