mirror of https://github.com/jumpserver/jumpserver
perf: Speed up loading of the permission list on the user detail page & add a reentrant lock
parent e599bca951
commit 9be3cbb936
@@ -4,9 +4,7 @@ from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView
from django.shortcuts import get_object_or_404
from django.utils.decorators import method_decorator

from assets.locks import NodeTreeUpdateLock
from common.utils import get_logger, get_object_or_none
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser
from orgs.mixins.api import OrgBulkModelViewSet

@@ -2,7 +2,7 @@ from typing import List

from common.utils.common import timeit
from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination
from assets.pagination import NodeAssetTreePagination
from common.utils import lazyproperty
from assets.utils import get_node, is_query_node_all_assets

@@ -81,7 +81,7 @@ class SerializeToTreeNodeMixin:


class FilterAssetByNodeMixin:
    pagination_class = AssetLimitOffsetPagination
    pagination_class = NodeAssetTreePagination

    @lazyproperty
    def is_query_node_all_assets(self):

@@ -8,7 +8,6 @@ from rest_framework.response import Response
from rest_framework.decorators import action
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed

from common.const.http import POST

@@ -25,10 +24,10 @@ from ..models import Node
from ..tasks import (
    update_node_assets_hardware_info_manual,
    test_node_assets_connectivity_manual,
    check_node_assets_amount_task
)
from .. import serializers
from .mixin import SerializeToTreeNodeMixin
from assets.locks import NodeTreeUpdateLock


logger = get_logger(__file__)

@@ -54,6 +53,11 @@ class NodeViewSet(OrgModelViewSet):
        serializer.validated_data["key"] = child_key
        serializer.save()

    @action(methods=[POST], detail=False, url_path='check_assets_amount_task')
    def check_assets_amount_task(self, request):
        task = check_node_assets_amount_task.delay(current_org.id)
        return Response(data={'task': task.id})

    def perform_update(self, serializer):
        node = self.get_object()
        if node.is_org_root() and node.value != serializer.validated_data['value']:

@@ -15,7 +15,6 @@ class NodeTreeUpdateLock(DistributedLock):
        )
        return name

    def __init__(self, blocking=True):
    def __init__(self):
        name = self.get_name()
        super().__init__(name=name, blocking=blocking,
                         release_lock_on_transaction_commit=True)
        super().__init__(name=name, release_on_transaction_commit=True, reentrant=True)

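A minimal usage sketch (not part of the commit; the helper names are made up): because the lock is now created with reentrant=True, a function that already holds NodeTreeUpdateLock can call another locked function in the same thread without deadlocking.

from assets.locks import NodeTreeUpdateLock

@NodeTreeUpdateLock()
def recount_subtree():
    ...  # hypothetical helper that touches the node tree

@NodeTreeUpdateLock()
def rebuild_tree():
    # Re-enters the same named lock in the same thread; the second acquire
    # is satisfied locally instead of blocking on Redis.
    recount_subtree()
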
@@ -1,4 +1,4 @@
# Generated by Django 3.1 on 2021-02-04 09:49
# Generated by Django 3.1 on 2021-02-08 10:02

from django.db import migrations

@@ -10,8 +10,8 @@ class Migration(migrations.Migration):
    ]

    operations = [
        migrations.RemoveField(
            model_name='node',
            name='assets_amount',
        migrations.AlterModelOptions(
            name='asset',
            options={'ordering': ['hostname'], 'verbose_name': 'Asset'},
        ),
    ]

@@ -353,4 +353,4 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
    class Meta:
        unique_together = [('org_id', 'hostname')]
        verbose_name = _("Asset")
        ordering = ["hostname", "ip"]
        ordering = ["hostname", ]

@@ -425,11 +425,6 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
            node_ids.update(_ids)
        return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()

    @property
    def assets_amount(self):
        assets_id = self.get_all_assets_id()
        return len(assets_id)

    def get_all_assets_id(self):
        assets_id = self.get_all_assets_id_by_node_key(org_id=self.org_id, node_key=self.key)
        return set(assets_id)

@@ -550,6 +545,7 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
    date_create = models.DateTimeField(auto_now_add=True)
    parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"),
                                  db_index=True, default='')
    assets_amount = models.IntegerField(default=0)

    objects = OrgManager.from_queryset(NodeQuerySet)()
    is_node = True

@@ -1,39 +1,51 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request

from common.utils import get_logger
from assets.models import Node

logger = get_logger(__name__)


class AssetPaginationBase(LimitOffsetPagination):

    def init_attrs(self, queryset, request: Request, view=None):
        self._request = request
        self._view = view
        self._user = request.user

    def paginate_queryset(self, queryset, request: Request, view=None):
        self.init_attrs(queryset, request, view)
        return super().paginate_queryset(queryset, request, view=None)

class AssetLimitOffsetPagination(LimitOffsetPagination):
    """
    Must be used together with `assets.api.mixin.FilterAssetByNodeMixin`
    """
    def get_count(self, queryset):
        """
        1. If querying all assets under a node, use Node.assets_amount as the count
        2. If there are other filter conditions, fall back to super()
        3. If querying only the assets directly under the node, fall back to super()
        """
        exclude_query_params = {
            self.limit_query_param,
            self.offset_query_param,
            'node', 'all', 'show_current_asset',
            'node_id', 'display', 'draw', 'fields_size',
            'key', 'all', 'show_current_asset',
            'cache_policy', 'display', 'draw',
            'order', 'node', 'node_id', 'fields_size',
        }

        for k, v in self._request.query_params.items():
            if k not in exclude_query_params and v is not None:
                logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
                return super().get_count(queryset)
        node_assets_count = self.get_count_from_nodes(queryset)
        if node_assets_count is None:
            return super().get_count(queryset)
        return node_assets_count

    def get_count_from_nodes(self, queryset):
        raise NotImplementedError


class NodeAssetTreePagination(AssetPaginationBase):
    def get_count_from_nodes(self, queryset):
        is_query_all = self._view.is_query_node_all_assets
        if is_query_all:
            node = self._view.node
            if not node:
                node = Node.org_root()
            logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}')
            return node.assets_amount
        return super().get_count(queryset)

    def paginate_queryset(self, queryset, request: Request, view=None):
        self._request = request
        self._view = view
        return super().paginate_queryset(queryset, request, view=None)
        return None

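A hedged sketch (the subclass name is made up) of the contract the new base class defines: get_count_from_nodes() may return a pre-computed count, and returning None makes get_count() fall back to a normal COUNT(*) over the queryset.

class CachedCountPagination(AssetPaginationBase):  # hypothetical subclass
    def get_count_from_nodes(self, queryset):
        node = getattr(self._view, 'node', None)
        if node is not None:
            return node.assets_amount   # use the counter maintained on the Node row
        return None                     # no node in play: let get_count() run COUNT(*)
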
@@ -1,2 +1,3 @@
from .common import *
from .maintain_nodes_tree import *
from .node_assets_amount import *
from .node_assets_mapping import *

@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from django.db.models import Q, F
from django.dispatch import receiver
from django.db.models.signals import (
    m2m_changed
)

from orgs.utils import ensure_in_real_or_default_org
from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR
from common.utils import get_logger
from assets.models import Asset, Node, compute_parent_key
from assets.locks import NodeTreeUpdateLock


logger = get_logger(__file__)


@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
    # `pre_clear` is not allowed, because that signal carries no `pk_set`
    # [Official docs](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
    refused = (PRE_CLEAR,)
    if action in refused:
        raise ValueError

    mapper = {
        PRE_ADD: add,
        POST_REMOVE: sub
    }
    if action not in mapper:
        return

    operator = mapper[action]

    if reverse:
        node: Node = instance
        asset_pk_set = set(pk_set)
        NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator)
    else:
        asset_pk = instance.id
        # Nodes directly related to the asset
        node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
        NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator)


class NodeAssetsAmountUtils:

    @classmethod
    def _remove_ancestor_keys(cls, ancestor_key, tree_set):
        # `ancestor_key` must be non-empty here, to prevent an infinite loop caused by bad data
        # Membership in the set tells whether a key has already been handled
        while ancestor_key and ancestor_key in tree_set:
            tree_set.remove(ancestor_key)
            ancestor_key = compute_parent_key(ancestor_key)

    @classmethod
    def _is_asset_exists_in_node(cls, asset_pk, node_key):
        exists = Asset.objects.filter(
            Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
        ).filter(id=asset_pk).exists()
        return exists

    @classmethod
    @ensure_in_real_or_default_org
    @NodeTreeUpdateLock()
    def update_nodes_asset_amount(cls, node_keys, asset_pk, operator):
        """
        Update the counters when one asset's relations to multiple nodes change

        :param node_keys: set of node keys
        :param asset_pk: asset id
        :param operator: the operation (add or sub)
        """

        # The ancestors of all related nodes, together forming a partial tree
        ancestor_keys = set()
        for key in node_keys:
            ancestor_keys.update(Node.get_node_ancestor_keys(key))

        # A related node may itself be an ancestor of another related node; if so, drop it from the related set
        node_keys -= ancestor_keys

        to_update_keys = []
        for key in node_keys:
            # Walk the related nodes, handling each node and its ancestors
            # Check whether this node's subtree still contains the asset being processed
            exists = cls._is_asset_exists_in_node(asset_pk, key)
            parent_key = compute_parent_key(key)

            if exists:
                # The asset is still under this node, so neither it nor its ancestors need updating
                cls._remove_ancestor_keys(parent_key, ancestor_keys)
                continue
            else:
                # Not present, so this node's counter must be updated
                to_update_keys.append(key)
                # `parent_key` must be non-empty here, to prevent an infinite loop caused by bad data
                # Membership in the set tells whether a key has already been handled
                while parent_key and parent_key in ancestor_keys:
                    exists = cls._is_asset_exists_in_node(asset_pk, parent_key)
                    if exists:
                        cls._remove_ancestor_keys(parent_key, ancestor_keys)
                        break
                    else:
                        to_update_keys.append(parent_key)
                        ancestor_keys.remove(parent_key)
                        parent_key = compute_parent_key(parent_key)

        Node.objects.filter(key__in=to_update_keys).update(
            assets_amount=operator(F('assets_amount'), 1)
        )

    @classmethod
    @ensure_in_real_or_default_org
    @NodeTreeUpdateLock()
    def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add):
        """
        Update the counters when one node's relations to multiple assets change

        :param node: the node instance
        :param asset_pk_set: set of asset ids; the method does not modify it
        :param operator: the operation (add or sub)
        * -> Node
        # -> Asset

                 * [3]
                / \
               *   * [2]
                  / \
                 *   * [1]
                /   / \
           * [a]   #   # [b]

        """
        # Get the keys of node [1] and its ancestors (including itself), i.e. the keys of nodes [1, 2, 3]
        ancestor_keys = node.get_ancestor_keys(with_self=True)
        ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
        to_update = []
        for ancestor in ancestors:
            # Iterate over the ancestor keys in the order [1] -> [2] -> [3]
            # Check whether this node or its descendants already contain any of the assets being
            # operated on, and drop those from the working set: they are duplicates, and neither
            # adding nor removing them changes the node's asset count

            asset_pk_set -= set(Asset.objects.filter(
                id__in=asset_pk_set
            ).filter(
                Q(nodes__key__istartswith=f'{ancestor.key}:') |
                Q(nodes__key=ancestor.key)
            ).distinct().values_list('id', flat=True))
            if not asset_pk_set:
                # The working set is empty, so all assets were duplicates and the count does not
                # change; since this node already contains them, its ancestors certainly contain
                # them too, so they can be skipped as well
                break
            ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
            to_update.append(ancestor)
        Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))

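A small illustrative sketch (the node key used here is made up): `operator` is plain `operator.add`/`operator.sub`, and applying it to an F() expression builds an SQL-side expression, so the counter is adjusted atomically in the database instead of being read and rewritten in Python.

from operator import add, sub
from django.db.models import F
from assets.models import Node

# assets_amount = assets_amount + 1, evaluated in SQL
Node.objects.filter(key='1:1').update(assets_amount=add(F('assets_amount'), 1))
# assets_amount = assets_amount - 1, evaluated in SQL
Node.objects.filter(key='1:1').update(assets_amount=sub(F('assets_amount'), 1))
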
@@ -0,0 +1,33 @@
from celery import shared_task
from django.utils.translation import gettext_lazy as _

from orgs.models import Organization
from orgs.utils import tmp_to_org
from ops.celery.decorator import register_as_period_task
from assets.utils import check_node_assets_amount

from common.utils.lock import AcquireFailed
from common.utils import get_logger

logger = get_logger(__file__)


@shared_task
def check_node_assets_amount_task(orgid=None):
    if orgid is None:
        orgs = [*Organization.objects.all(), Organization.default()]
    else:
        orgs = [Organization.get_instance(orgid)]

    for org in orgs:
        try:
            with tmp_to_org(org):
                check_node_assets_amount()
        except AcquireFailed:
            logger.error(_('The task of self-checking is already running and cannot be started repeatedly'))


@register_as_period_task(crontab='0 2 * * *')
@shared_task
def check_node_assets_amount_period_task():
    check_node_assets_amount_task()

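A usage sketch (not part of the diff; `org` stands for any Organization instance): the API action above queues this task with the current org id, and the AsyncResult id is what the endpoint returns to the caller.

from assets.tasks import check_node_assets_amount_task  # import path assumed from the `..tasks` import above

result = check_node_assets_amount_task.delay(org.id)
print(result.id)  # the task id handed back by the check_assets_amount_task API action
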
@@ -5,12 +5,45 @@ from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none,
from common.http import is_true
from common.struct import Stack
from common.db.models import output_as_string
from orgs.utils import ensure_in_real_or_default_org, current_org

from .models import Node
from .locks import NodeTreeUpdateLock
from .models import Node, Asset

logger = get_logger(__file__)


@NodeTreeUpdateLock()
@ensure_in_real_or_default_org
def check_node_assets_amount():
    logger.info(f'Check node assets amount {current_org}')
    nodes = list(Node.objects.all().only('id', 'key', 'assets_amount'))
    nodeid_assetid_pairs = list(Asset.nodes.through.objects.all().values_list('node_id', 'asset_id'))

    nodekey_assetids_mapper = defaultdict(set)
    nodeid_nodekey_mapper = {}
    for node in nodes:
        nodeid_nodekey_mapper[node.id] = node.key

    for nodeid, assetid in nodeid_assetid_pairs:
        if nodeid not in nodeid_nodekey_mapper:
            continue
        nodekey = nodeid_nodekey_mapper[nodeid]
        nodekey_assetids_mapper[nodekey].add(assetid)

    util = NodeAssetsUtil(nodes, nodekey_assetids_mapper)
    util.generate()

    to_updates = []
    for node in nodes:
        assets_amount = util.get_assets_amount(node.key)
        if node.assets_amount != assets_amount:
            logger.error(f'Node[{node.key}] assets amount error {node.assets_amount} != {assets_amount}')
            node.assets_amount = assets_amount
            to_updates.append(node)
    Node.objects.bulk_update(to_updates, fields=('assets_amount',))


def is_query_node_all_assets(request):
    request = request
    query_all_arg = request.query_params.get('all', 'true')

@@ -104,5 +137,3 @@ class NodeAssetsUtil:
        util = cls(nodes, mapping)
        util.generate()
        return util

@@ -8,6 +8,7 @@ from django.db import transaction
from common.utils import get_logger
from common.utils.inspect import copy_function_args
from apps.jumpserver.const import CONFIG
from common.local import thread_local

logger = get_logger(__file__)

@@ -16,24 +17,28 @@ class AcquireFailed(RuntimeError):
    pass


class LockHasTimeOut(RuntimeError):
    pass


class DistributedLock(RedisLock):
    def __init__(self, name, blocking=True, expire=None, release_lock_on_transaction_commit=False,
                 release_raise_exc=False, auto_renewal_seconds=60*2):
    def __init__(self, name, *, expire=None, release_on_transaction_commit=False,
                 reentrant=False, release_raise_exc=False, auto_renewal_seconds=60):
        """
        A distributed lock built on Redis

        :param name:
            Name of the lock; it must be globally unique
        :param blocking:
            Only takes effect when the lock is used as a decorator or in a `with` statement.
        :param expire:
            Expiry time of the lock
        :param release_lock_on_transaction_commit:
        :param release_on_transaction_commit:
            Whether to delay releasing the lock until the current transaction commits
        :param release_raise_exc:
            Whether releasing a lock that is not held raises an exception or fails silently
        :param auto_renewal_seconds:
            Renewal interval while holding a lock without an expiry; see `redis_lock.Lock#auto_renewal`
        :param reentrant:
            Whether the lock is reentrant
        """
        self.kwargs_copy = copy_function_args(self.__init__, locals())
        redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)

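A construction sketch (the lock name is a made-up example) showing the new keyword-only signature described in the docstring above:

lock = DistributedLock(
    'node_tree_update_<org_id>',          # globally unique name
    expire=60,                            # lock expiry, in seconds
    reentrant=True,                       # the same thread may acquire it again
    release_on_transaction_commit=True,   # defer release until the transaction commits
)
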
@@ -45,28 +50,20 @@ class DistributedLock(RedisLock):
            auto_renewal = False

        super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal)
        self._blocking = blocking
        self._release_lock_on_transaction_commit = release_lock_on_transaction_commit
        self._release_on_transaction_commit = release_on_transaction_commit
        self._release_raise_exc = release_raise_exc
        self._reentrant = reentrant
        self._acquired_reentrant_lock = False
        self._thread_id = threading.current_thread().ident

    def __enter__(self):
        thread_id = threading.current_thread().ident
        logger.debug(f'Attempt to acquire global lock: thread {thread_id} lock {self._name}')
        acquired = self.acquire(blocking=self._blocking)
        if self._blocking and not acquired:
            logger.debug(f'Not acquired lock, but blocking=True, thread {thread_id} lock {self._name}')
            raise EnvironmentError("Lock wasn't acquired, but blocking=True")
        acquired = self.acquire(blocking=True)
        if not acquired:
            logger.debug(f'Not acquired the lock, thread {thread_id} lock {self._name}')
            raise AcquireFailed
        logger.debug(f'Acquire lock success, thread {thread_id} lock {self._name}')
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        if self._release_lock_on_transaction_commit:
            transaction.on_commit(self.release)
        else:
            self.release()
        self.release()

    def __call__(self, func):
        @wraps(func)

@@ -82,9 +79,105 @@ class DistributedLock(RedisLock):
            return True
        return False

    def release(self):
    def locked_by_current_thread(self):
        if self.locked():
            owner_id = self.get_owner_id()
            local_owner_id = getattr(thread_local, self.name, None)

            if local_owner_id and owner_id == local_owner_id:
                return True
        return False

    def acquire(self, blocking=True, timeout=None):
        if self._reentrant:
            if self.locked_by_current_thread():
                self._acquired_reentrant_lock = True
                logger.debug(
                    f'I[{self.id}] reentry lock[{self.name}] in thread[{self._thread_id}].')
                return True

            logger.debug(f'I[{self.id}] attempt acquire reentrant-lock[{self.name}].')
            acquired = super().acquire(blocking=blocking, timeout=timeout)
            if acquired:
                logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] now.')
                setattr(thread_local, self.name, self.id)
            else:
                logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] failed.')
            return acquired
        else:
            logger.debug(f'I[{self.id}] attempt acquire lock[{self.name}].')
            acquired = super().acquire(blocking=blocking, timeout=timeout)
            logger.debug(f'I[{self.id}] acquired lock[{self.name}] {acquired}.')
            return acquired

    @property
    def name(self):
        return self._name

    def _raise_exc_with_log(self, msg, *, exc_cls=NotAcquired):
        e = exc_cls(msg)
        logger.error(msg)
        self._raise_exc(e)

    def _raise_exc(self, e):
        if self._release_raise_exc:
            raise e

    def _release_on_reentrant_locked_by_brother(self):
        if self._acquired_reentrant_lock:
            self._acquired_reentrant_lock = False
            logger.debug(f'I[{self.id}] released reentrant-lock[{self.name}] owner[{self.get_owner_id()}] in thread[{self._thread_id}]')
            return
        else:
            self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired by me[{self.id}].')

    def _release_on_reentrant_locked_by_me(self):
        logger.debug(f'I[{self.id}] release reentrant-lock[{self.name}] in thread[{self._thread_id}]')

        id = getattr(thread_local, self.name, None)
        if id != self.id:
            raise PermissionError(f'Reentrant-lock[{self.name}] is not locked by me[{self.id}], owner[{id}]')
        try:
            super().release()
        except AcquireFailed as e:
            if self._release_raise_exc:
                raise e
            # Make sure the thread_local marker is removed first,
            delattr(thread_local, self.name)
        except AttributeError:
            pass
        finally:
            try:
                # This handles an edge case:
                # the lock is judged to be mine -> the lock times out -> releasing it raises
                # that error should be silenced here
                self._release_redis_lock()
            except NotAcquired:
                pass

    def _release_redis_lock(self):
        # Lowest-level API
        super().release()

    def _release(self):
        try:
            self._release_redis_lock()
        except NotAcquired as e:
            logger.error(f'I[{self.id}] release lock[{self.name}] failed {e}')
            self._raise_exc(e)

    def release(self):
        _release = self._release

        # Handle the reentrant case
        if self._reentrant:
            if self.locked_by_current_thread():
                if self.locked_by_me():
                    _release = self._release_on_reentrant_locked_by_me
                else:
                    _release = self._release_on_reentrant_locked_by_brother
            else:
                self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired in current-thread[{self._thread_id}]')

        # Decide whether to release only after the transaction commits
        if self._release_on_transaction_commit:
            logger.debug(f'I[{self.id}] release lock[{self.name}] on transaction commit ...')
            transaction.on_commit(_release)
        else:
            _release()

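A sketch of the transaction interaction (assuming the caller opens the atomic block): with release_on_transaction_commit=True, release() registers the real release with transaction.on_commit, so other workers keep seeing the lock as held until the node-tree changes are actually committed.

from django.db import transaction
from assets.locks import NodeTreeUpdateLock

with transaction.atomic():
    with NodeTreeUpdateLock():
        ...  # mutate nodes / asset counters inside the transaction
    # __exit__ has run, but the Redis lock is only released after COMMIT
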
@@ -186,6 +186,10 @@ def org_aware_func(org_arg_name):
current_org = LocalProxy(get_current_org)


def ensure_in_real_or_default_org():
    if not current_org or current_org.is_root():
        raise ValueError('You must in a real or default org!')
def ensure_in_real_or_default_org(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not current_org or current_org.is_root():
            raise ValueError('You must in a real or default org!')
        return func(*args, **kwargs)
    return wrapper

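A sketch of the call-site change this refactor implies (the function name is taken from a later hunk): the org guard moves from an explicit call at the top of each function body to a decorator on the function itself.

# Before:
#     def rebuild_user_granted_tree(self):
#         ensure_in_real_or_default_org()
#         ...

@ensure_in_real_or_default_org
def rebuild_user_granted_tree(self):
    ...
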
@@ -26,12 +26,6 @@ class AssetPermissionViewSet(BasePermissionViewSet):
        'node_id', 'node', 'asset_id', 'hostname', 'ip'
    ]

    def get_queryset(self):
        queryset = super().get_queryset().prefetch_related(
            "nodes", "assets", "users", "user_groups", "system_users"
        )
        return queryset

    def filter_node(self, queryset):
        node_id = self.request.query_params.get('node_id')
        node_name = self.request.query_params.get('node')

@@ -14,7 +14,6 @@ from .mixin import RoleUserMixin, RoleAdminMixin
from perms.utils.asset.user_permission import (
    UserGrantedTreeBuildUtils, get_user_all_asset_perm_ids,
    UserGrantedNodesQueryUtils, UserGrantedAssetsQueryUtils,
    QuerySetStage,
)
from perms.models import AssetPermission, PermNode
from assets.models import Asset

@@ -44,10 +43,10 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
    def add_favorite_resource(self, data: list, nodes_query_utils, assets_query_utils):
        favorite_node = nodes_query_utils.get_favorite_node()

        qs_state = QuerySetStage().annotate(
        favorite_assets = assets_query_utils.get_favorite_assets()
        favorite_assets = favorite_assets.annotate(
            parent_key=Value(favorite_node.key, output_field=CharField())
        ).prefetch_related('platform')
        favorite_assets = assets_query_utils.get_favorite_assets(qs_stage=qs_state, only=())

        data.extend(self.serialize_nodes([favorite_node], with_asset_amount=True))
        data.extend(self.serialize_assets(favorite_assets))

@@ -59,13 +58,11 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
        data.extend(self.serialize_nodes(nodes, with_asset_amount=True))

    def add_assets(self, data: list, assets_query_utils: UserGrantedAssetsQueryUtils):
        qs_stage = QuerySetStage().annotate(parent_key=F('nodes__key')).prefetch_related('platform')

        if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
            all_assets = assets_query_utils.get_direct_granted_nodes_assets(qs_stage=qs_stage)
            all_assets = assets_query_utils.get_direct_granted_nodes_assets()
        else:
            all_assets = assets_query_utils.get_all_granted_assets(qs_stage=qs_stage)

            all_assets = assets_query_utils.get_all_granted_assets()
        all_assets = all_assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform')
        data.extend(self.serialize_assets(all_assets))

    @tmp_to_root_org()

@@ -144,8 +141,6 @@ class GrantedNodeChildrenWithAssetsAsTreeApiMixin(SerializeToTreeNodeMixin,
        assets = assets_query_utils.get_node_assets(key)
        assets = assets.prefetch_related('platform')

        user = self.user

        tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
        tree_assets = self.serialize_assets(assets, key)
        return Response(data=[*tree_nodes, *tree_assets])

@@ -45,7 +45,7 @@ class BasePermissionViewSet(OrgBulkModelViewSet):
        if not self.is_query_all():
            queryset = queryset.filter(users=user)
            return queryset
        groups = user.groups.all()
        groups = list(user.groups.all().values_list('id', flat=True))
        queryset = queryset.filter(
            Q(users=user) | Q(user_groups__in=groups)
        ).distinct()

@@ -1,4 +1,4 @@
# Generated by Django 3.1 on 2021-02-04 09:49
# Generated by Django 3.1 on 2021-02-08 07:15

import assets.models.node
from django.conf import settings

@@ -9,8 +9,8 @@ import django.db.models.deletion
class Migration(migrations.Migration):

    dependencies = [
        ('assets', '0066_remove_node_assets_amount'),
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ('assets', '0065_auto_20210121_1549'),
        ('perms', '0017_auto_20210104_0435'),
    ]

@@ -1,37 +1,17 @@
from rest_framework.pagination import LimitOffsetPagination
from django.conf import settings
from rest_framework.request import Request
from django.db.models import Sum

from assets.pagination import AssetPaginationBase
from perms.models import UserAssetGrantedTreeNodeRelation
from common.utils import get_logger

logger = get_logger(__name__)


class GrantedAssetPaginationBase(LimitOffsetPagination):

    def paginate_queryset(self, queryset, request: Request, view=None):
        self._request = request
        self._view = view
        self._user = request.user
        return super().paginate_queryset(queryset, request, view=None)

    def get_count(self, queryset):
        exclude_query_params = {
            self.limit_query_param,
            self.offset_query_param,
            'key', 'all', 'show_current_asset',
            'cache_policy', 'display', 'draw',
            'order',
        }
        for k, v in self._request.query_params.items():
            if k not in exclude_query_params and v is not None:
                logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
                return super().get_count(queryset)
        return self.get_count_from_nodes(queryset)

    def get_count_from_nodes(self, queryset):
        raise NotImplementedError
class GrantedAssetPaginationBase(AssetPaginationBase):
    def init_attrs(self, queryset, request: Request, view=None):
        super().init_attrs(queryset, request, view)
        self._user = view.user


class NodeGrantedAssetPagination(GrantedAssetPaginationBase):

@@ -42,11 +22,13 @@ class NodeGrantedAssetPagination(GrantedAssetPaginationBase):
            return node.assets_amount
        else:
            logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} not has `pagination_node` -> {self._request.get_full_path()}')
            return super().get_count(queryset)
            return None


class AllGrantedAssetPagination(GrantedAssetPaginationBase):
    def get_count_from_nodes(self, queryset):
        if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
            return None
        assets_amount = sum(UserAssetGrantedTreeNodeRelation.objects.filter(
            user=self._user, node_parent_key=''
        ).values_list('node_assets_amount', flat=True))

@@ -3,9 +3,12 @@

from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from django.db.models import Prefetch

from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from perms.models import AssetPermission, Action
from assets.models import Asset, Node, SystemUser
from users.models import User, UserGroup

__all__ = [
    'AssetPermissionSerializer',

@@ -68,5 +71,11 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
    @classmethod
    def setup_eager_loading(cls, queryset):
        """ Perform necessary eager loading of data. """
        queryset = queryset.prefetch_related('users', 'user_groups', 'assets', 'nodes', 'system_users')
        queryset = queryset.prefetch_related(
            Prefetch('system_users', queryset=SystemUser.objects.only('id')),
            Prefetch('user_groups', queryset=UserGroup.objects.only('id')),
            Prefetch('users', queryset=User.objects.only('id')),
            Prefetch('assets', queryset=Asset.objects.only('id')),
            Prefetch('nodes', queryset=Node.objects.only('id'))
        )
        return queryset

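A sketch contrasting the two prefetch styles in this hunk (the variable names are illustrative): the plain string lookup loads every column of each related row, while Prefetch(... only('id')) loads only the primary keys the permission serializer actually renders.

from django.db.models import Prefetch
from assets.models import Node

qs_all_columns = queryset.prefetch_related('nodes')                                           # SELECT * from the related table
qs_ids_only = queryset.prefetch_related(Prefetch('nodes', queryset=Node.objects.only('id')))  # SELECT only the id column
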
@@ -115,8 +115,8 @@ class UnionQuerySet(QuerySet):
    def __getitem__(self, item):
        return self.__execute()[item]

    def __next__(self):
        return next(self.__execute())
    def __iter__(self):
        return iter(self.__execute())

    @classmethod
    def test_it(cls):

@@ -299,12 +299,12 @@ class UserGrantedTreeRefreshController:
        cls.remove_builed_orgs_from_users(orgs_id, users_id)

    @classmethod
    @ensure_in_real_or_default_org
    def add_need_refresh_on_nodes_assets_relate_change(cls, node_ids, asset_ids):
        """
        1. Compute the permissions related to these assets
        2. Compute the permissions related to these nodes and their ancestor nodes
        """
        ensure_in_real_or_default_org()

        node_ids = set(node_ids)
        ancestor_node_keys = set()

@@ -340,8 +340,8 @@ class UserGrantedTreeRefreshController:
        cls.add_need_refresh_by_asset_perm_ids(perm_ids)

    @classmethod
    @ensure_in_real_or_default_org
    def add_need_refresh_by_asset_perm_ids(cls, asset_perm_ids):
        ensure_in_real_or_default_org()

        group_ids = AssetPermission.user_groups.through.objects.filter(
            assetpermission_id__in=asset_perm_ids

@@ -429,8 +429,8 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
        return asset_ids

    @timeit
    @ensure_in_real_or_default_org
    def rebuild_user_granted_tree(self):
        ensure_in_real_or_default_org()
        logger.info(f'Rebuild user:{self.user} tree in org:{current_org}')

        user = self.user

@@ -618,13 +618,13 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):

class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):

    def get_favorite_assets(self, only=('id', )) -> QuerySet:
    def get_favorite_assets(self) -> QuerySet:
        favorite_asset_ids = FavoriteAsset.objects.filter(
            user=self.user
        ).values_list('asset_id', flat=True)
        favorite_asset_ids = list(favorite_asset_ids)
        assets = self.get_all_granted_assets()
        assets = assets.filter(id__in=favorite_asset_ids).only(*only)
        assets = assets.filter(id__in=favorite_asset_ids)
        return assets

    def get_ungroup_assets(self) -> AssetQuerySet:

@@ -670,7 +670,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
        granted_status = node.get_granted_status(self.user)

        if granted_status == NodeFrom.granted:
            assets = Asset.objects.order_by().filter(nodes_id=node.id)
            assets = Asset.objects.order_by().filter(nodes__id=node.id)
            return assets
        elif granted_status == NodeFrom.asset:
            return self._get_indirect_granted_node_assets(node.id)

@@ -678,7 +678,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
        return Asset.objects.none()

    def _get_indirect_granted_node_assets(self, id) -> AssetQuerySet:
        assets = Asset.objects.order_by().filter(nodes_id=id) & self.get_direct_granted_assets()
        assets = Asset.objects.order_by().filter(nodes__id=id).distinct() & self.get_direct_granted_assets()
        return assets

    def _get_indirect_granted_node_all_assets(self, node) -> QuerySet: