perf: 优化用户详情页授权列表加载速度&添加可重入锁

pull/5634/head
xinwen 2021-02-08 14:59:20 +08:00 committed by 老广
parent e599bca951
commit 9be3cbb936
22 changed files with 434 additions and 124 deletions

View File

@ -4,9 +4,7 @@ from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView from rest_framework.generics import RetrieveAPIView
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.decorators import method_decorator
from assets.locks import NodeTreeUpdateLock
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet

View File

@ -2,7 +2,7 @@ from typing import List
from common.utils.common import timeit from common.utils.common import timeit
from assets.models import Node, Asset from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination from assets.pagination import NodeAssetTreePagination
from common.utils import lazyproperty from common.utils import lazyproperty
from assets.utils import get_node, is_query_node_all_assets from assets.utils import get_node, is_query_node_all_assets
@ -81,7 +81,7 @@ class SerializeToTreeNodeMixin:
class FilterAssetByNodeMixin: class FilterAssetByNodeMixin:
pagination_class = AssetLimitOffsetPagination pagination_class = NodeAssetTreePagination
@lazyproperty @lazyproperty
def is_query_node_all_assets(self): def is_query_node_all_assets(self):

View File

@ -8,7 +8,6 @@ from rest_framework.response import Response
from rest_framework.decorators import action from rest_framework.decorators import action
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404 from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed from django.db.models.signals import m2m_changed
from common.const.http import POST from common.const.http import POST
@ -25,10 +24,10 @@ from ..models import Node
from ..tasks import ( from ..tasks import (
update_node_assets_hardware_info_manual, update_node_assets_hardware_info_manual,
test_node_assets_connectivity_manual, test_node_assets_connectivity_manual,
check_node_assets_amount_task
) )
from .. import serializers from .. import serializers
from .mixin import SerializeToTreeNodeMixin from .mixin import SerializeToTreeNodeMixin
from assets.locks import NodeTreeUpdateLock
logger = get_logger(__file__) logger = get_logger(__file__)
@ -54,6 +53,11 @@ class NodeViewSet(OrgModelViewSet):
serializer.validated_data["key"] = child_key serializer.validated_data["key"] = child_key
serializer.save() serializer.save()
@action(methods=[POST], detail=False, url_path='check_assets_amount_task')
def check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
def perform_update(self, serializer): def perform_update(self, serializer):
node = self.get_object() node = self.get_object()
if node.is_org_root() and node.value != serializer.validated_data['value']: if node.is_org_root() and node.value != serializer.validated_data['value']:

View File

@ -15,7 +15,6 @@ class NodeTreeUpdateLock(DistributedLock):
) )
return name return name
def __init__(self, blocking=True): def __init__(self):
name = self.get_name() name = self.get_name()
super().__init__(name=name, blocking=blocking, super().__init__(name=name, release_on_transaction_commit=True, reentrant=True)
release_lock_on_transaction_commit=True)

View File

@ -1,4 +1,4 @@
# Generated by Django 3.1 on 2021-02-04 09:49 # Generated by Django 3.1 on 2021-02-08 10:02
from django.db import migrations from django.db import migrations
@ -10,8 +10,8 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.RemoveField( migrations.AlterModelOptions(
model_name='node', name='asset',
name='assets_amount', options={'ordering': ['hostname'], 'verbose_name': 'Asset'},
), ),
] ]

View File

@ -353,4 +353,4 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
class Meta: class Meta:
unique_together = [('org_id', 'hostname')] unique_together = [('org_id', 'hostname')]
verbose_name = _("Asset") verbose_name = _("Asset")
ordering = ["hostname", "ip"] ordering = ["hostname", ]

View File

@ -425,11 +425,6 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
node_ids.update(_ids) node_ids.update(_ids)
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct() return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
@property
def assets_amount(self):
assets_id = self.get_all_assets_id()
return len(assets_id)
def get_all_assets_id(self): def get_all_assets_id(self):
assets_id = self.get_all_assets_id_by_node_key(org_id=self.org_id, node_key=self.key) assets_id = self.get_all_assets_id_by_node_key(org_id=self.org_id, node_key=self.key)
return set(assets_id) return set(assets_id)
@ -550,6 +545,7 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
date_create = models.DateTimeField(auto_now_add=True) date_create = models.DateTimeField(auto_now_add=True)
parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"), parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"),
db_index=True, default='') db_index=True, default='')
assets_amount = models.IntegerField(default=0)
objects = OrgManager.from_queryset(NodeQuerySet)() objects = OrgManager.from_queryset(NodeQuerySet)()
is_node = True is_node = True

View File

@ -1,39 +1,51 @@
from rest_framework.pagination import LimitOffsetPagination from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request from rest_framework.request import Request
from common.utils import get_logger
from assets.models import Node from assets.models import Node
logger = get_logger(__name__)
class AssetPaginationBase(LimitOffsetPagination):
def init_attrs(self, queryset, request: Request, view=None):
self._request = request
self._view = view
self._user = request.user
def paginate_queryset(self, queryset, request: Request, view=None):
self.init_attrs(queryset, request, view)
return super().paginate_queryset(queryset, request, view=None)
class AssetLimitOffsetPagination(LimitOffsetPagination):
"""
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用
"""
def get_count(self, queryset): def get_count(self, queryset):
"""
1. 如果查询节点下的所有资产 count 使用 Node.assets_amount
2. 如果有其他过滤条件使用 super
3. 如果只查询该节点下的资产使用 super
"""
exclude_query_params = { exclude_query_params = {
self.limit_query_param, self.limit_query_param,
self.offset_query_param, self.offset_query_param,
'node', 'all', 'show_current_asset', 'key', 'all', 'show_current_asset',
'node_id', 'display', 'draw', 'fields_size', 'cache_policy', 'display', 'draw',
'order', 'node', 'node_id', 'fields_size',
} }
for k, v in self._request.query_params.items(): for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None: if k not in exclude_query_params and v is not None:
logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
return super().get_count(queryset) return super().get_count(queryset)
node_assets_count = self.get_count_from_nodes(queryset)
if node_assets_count is None:
return super().get_count(queryset)
return node_assets_count
def get_count_from_nodes(self, queryset):
raise NotImplementedError
class NodeAssetTreePagination(AssetPaginationBase):
def get_count_from_nodes(self, queryset):
is_query_all = self._view.is_query_node_all_assets is_query_all = self._view.is_query_node_all_assets
if is_query_all: if is_query_all:
node = self._view.node node = self._view.node
if not node: if not node:
node = Node.org_root() node = Node.org_root()
logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}')
return node.assets_amount return node.assets_amount
return super().get_count(queryset) return None
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
return super().paginate_queryset(queryset, request, view=None)

View File

@ -1,2 +1,3 @@
from .common import * from .common import *
from .maintain_nodes_tree import * from .node_assets_amount import *
from .node_assets_mapping import *

View File

@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from django.db.models import Q, F
from django.dispatch import receiver
from django.db.models.signals import (
m2m_changed
)
from orgs.utils import ensure_in_real_or_default_org
from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR
from common.utils import get_logger
from assets.models import Asset, Node, compute_parent_key
from assets.locks import NodeTreeUpdateLock
logger = get_logger(__file__)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator)
else:
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator)
class NodeAssetsAmountUtils:
@classmethod
def _remove_ancestor_keys(cls, ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
@classmethod
def _is_asset_exists_in_node(cls, asset_pk, node_key):
exists = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).filter(id=asset_pk).exists()
return exists
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_nodes_asset_amount(cls, node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = cls._is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
cls._remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = cls._is_asset_exists_in_node(asset_pk, parent_key)
if exists:
cls._remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))

View File

@ -0,0 +1,33 @@
from celery import shared_task
from django.utils.translation import gettext_lazy as _
from orgs.models import Organization
from orgs.utils import tmp_to_org
from ops.celery.decorator import register_as_period_task
from assets.utils import check_node_assets_amount
from common.utils.lock import AcquireFailed
from common.utils import get_logger
logger = get_logger(__file__)
@shared_task
def check_node_assets_amount_task(orgid=None):
if orgid is None:
orgs = [*Organization.objects.all(), Organization.default()]
else:
orgs = [Organization.get_instance(orgid)]
for org in orgs:
try:
with tmp_to_org(org):
check_node_assets_amount()
except AcquireFailed:
logger.error(_('The task of self-checking is already running and cannot be started repeatedly'))
@register_as_period_task(crontab='0 2 * * *')
@shared_task
def check_node_assets_amount_period_task():
check_node_assets_amount_task()

View File

@ -5,12 +5,45 @@ from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none,
from common.http import is_true from common.http import is_true
from common.struct import Stack from common.struct import Stack
from common.db.models import output_as_string from common.db.models import output_as_string
from orgs.utils import ensure_in_real_or_default_org, current_org
from .models import Node from .locks import NodeTreeUpdateLock
from .models import Node, Asset
logger = get_logger(__file__) logger = get_logger(__file__)
@NodeTreeUpdateLock()
@ensure_in_real_or_default_org
def check_node_assets_amount():
logger.info(f'Check node assets amount {current_org}')
nodes = list(Node.objects.all().only('id', 'key', 'assets_amount'))
nodeid_assetid_pairs = list(Asset.nodes.through.objects.all().values_list('node_id', 'asset_id'))
nodekey_assetids_mapper = defaultdict(set)
nodeid_nodekey_mapper = {}
for node in nodes:
nodeid_nodekey_mapper[node.id] = node.key
for nodeid, assetid in nodeid_assetid_pairs:
if nodeid not in nodeid_nodekey_mapper:
continue
nodekey = nodeid_nodekey_mapper[nodeid]
nodekey_assetids_mapper[nodekey].add(assetid)
util = NodeAssetsUtil(nodes, nodekey_assetids_mapper)
util.generate()
to_updates = []
for node in nodes:
assets_amount = util.get_assets_amount(node.key)
if node.assets_amount != assets_amount:
logger.error(f'Node[{node.key}] assets amount error {node.assets_amount} != {assets_amount}')
node.assets_amount = assets_amount
to_updates.append(node)
Node.objects.bulk_update(to_updates, fields=('assets_amount',))
def is_query_node_all_assets(request): def is_query_node_all_assets(request):
request = request request = request
query_all_arg = request.query_params.get('all', 'true') query_all_arg = request.query_params.get('all', 'true')
@ -104,5 +137,3 @@ class NodeAssetsUtil:
util = cls(nodes, mapping) util = cls(nodes, mapping)
util.generate() util.generate()
return util return util

View File

@ -8,6 +8,7 @@ from django.db import transaction
from common.utils import get_logger from common.utils import get_logger
from common.utils.inspect import copy_function_args from common.utils.inspect import copy_function_args
from apps.jumpserver.const import CONFIG from apps.jumpserver.const import CONFIG
from common.local import thread_local
logger = get_logger(__file__) logger = get_logger(__file__)
@ -16,24 +17,28 @@ class AcquireFailed(RuntimeError):
pass pass
class LockHasTimeOut(RuntimeError):
pass
class DistributedLock(RedisLock): class DistributedLock(RedisLock):
def __init__(self, name, blocking=True, expire=None, release_lock_on_transaction_commit=False, def __init__(self, name, *, expire=None, release_on_transaction_commit=False,
release_raise_exc=False, auto_renewal_seconds=60*2): reentrant=False, release_raise_exc=False, auto_renewal_seconds=60):
""" """
使用 redis 构造的分布式锁 使用 redis 构造的分布式锁
:param name: :param name:
锁的名字要全局唯一 锁的名字要全局唯一
:param blocking:
该参数只在锁作为装饰器或者 `with` 时有效
:param expire: :param expire:
锁的过期时间 锁的过期时间
:param release_lock_on_transaction_commit: :param release_on_transaction_commit:
是否在当前事务结束后再释放锁 是否在当前事务结束后再释放锁
:param release_raise_exc: :param release_raise_exc:
释放锁时如果没有持有锁是否抛异常或静默 释放锁时如果没有持有锁是否抛异常或静默
:param auto_renewal_seconds: :param auto_renewal_seconds:
当持有一个无限期锁的时候刷新锁的时间具体参考 `redis_lock.Lock#auto_renewal` 当持有一个无限期锁的时候刷新锁的时间具体参考 `redis_lock.Lock#auto_renewal`
:param reentrant:
是否可重入
""" """
self.kwargs_copy = copy_function_args(self.__init__, locals()) self.kwargs_copy = copy_function_args(self.__init__, locals())
redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD) redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
@ -45,28 +50,20 @@ class DistributedLock(RedisLock):
auto_renewal = False auto_renewal = False
super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal) super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal)
self._blocking = blocking self._release_on_transaction_commit = release_on_transaction_commit
self._release_lock_on_transaction_commit = release_lock_on_transaction_commit
self._release_raise_exc = release_raise_exc self._release_raise_exc = release_raise_exc
self._reentrant = reentrant
self._acquired_reentrant_lock = False
self._thread_id = threading.current_thread().ident
def __enter__(self): def __enter__(self):
thread_id = threading.current_thread().ident acquired = self.acquire(blocking=True)
logger.debug(f'Attempt to acquire global lock: thread {thread_id} lock {self._name}')
acquired = self.acquire(blocking=self._blocking)
if self._blocking and not acquired:
logger.debug(f'Not acquired lock, but blocking=True, thread {thread_id} lock {self._name}')
raise EnvironmentError("Lock wasn't acquired, but blocking=True")
if not acquired: if not acquired:
logger.debug(f'Not acquired the lock, thread {thread_id} lock {self._name}')
raise AcquireFailed raise AcquireFailed
logger.debug(f'Acquire lock success, thread {thread_id} lock {self._name}')
return self return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None): def __exit__(self, exc_type=None, exc_value=None, traceback=None):
if self._release_lock_on_transaction_commit: self.release()
transaction.on_commit(self.release)
else:
self.release()
def __call__(self, func): def __call__(self, func):
@wraps(func) @wraps(func)
@ -82,9 +79,105 @@ class DistributedLock(RedisLock):
return True return True
return False return False
def release(self): def locked_by_current_thread(self):
if self.locked():
owner_id = self.get_owner_id()
local_owner_id = getattr(thread_local, self.name, None)
if local_owner_id and owner_id == local_owner_id:
return True
return False
def acquire(self, blocking=True, timeout=None):
if self._reentrant:
if self.locked_by_current_thread():
self._acquired_reentrant_lock = True
logger.debug(
f'I[{self.id}] reentry lock[{self.name}] in thread[{self._thread_id}].')
return True
logger.debug(f'I[{self.id}] attempt acquire reentrant-lock[{self.name}].')
acquired = super().acquire(blocking=blocking, timeout=timeout)
if acquired:
logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] now.')
setattr(thread_local, self.name, self.id)
else:
logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] failed.')
return acquired
else:
logger.debug(f'I[{self.id}] attempt acquire lock[{self.name}].')
acquired = super().acquire(blocking=blocking, timeout=timeout)
logger.debug(f'I[{self.id}] acquired lock[{self.name}] {acquired}.')
return acquired
@property
def name(self):
return self._name
def _raise_exc_with_log(self, msg, *, exc_cls=NotAcquired):
e = exc_cls(msg)
logger.error(msg)
self._raise_exc(e)
def _raise_exc(self, e):
if self._release_raise_exc:
raise e
def _release_on_reentrant_locked_by_brother(self):
if self._acquired_reentrant_lock:
self._acquired_reentrant_lock = False
logger.debug(f'I[{self.id}] released reentrant-lock[{self.name}] owner[{self.get_owner_id()}] in thread[{self._thread_id}]')
return
else:
self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired by me[{self.id}].')
def _release_on_reentrant_locked_by_me(self):
logger.debug(f'I[{self.id}] release reentrant-lock[{self.name}] in thread[{self._thread_id}]')
id = getattr(thread_local, self.name, None)
if id != self.id:
raise PermissionError(f'Reentrant-lock[{self.name}] is not locked by me[{self.id}], owner[{id}]')
try: try:
super().release() # 这里要保证先删除 thread_local 的标记,
except AcquireFailed as e: delattr(thread_local, self.name)
if self._release_raise_exc: except AttributeError:
raise e pass
finally:
try:
# 这里处理的是边界情况,
# 判断锁是我的 -> 锁超时 -> 释放锁报错
# 此时的报错应该被静默
self._release_redis_lock()
except NotAcquired:
pass
def _release_redis_lock(self):
# 最底层 api
super().release()
def _release(self):
try:
self._release_redis_lock()
except NotAcquired as e:
logger.error(f'I[{self.id}] release lock[{self.name}] failed {e}')
self._raise_exc(e)
def release(self):
_release = self._release
# 处理可重入锁
if self._reentrant:
if self.locked_by_current_thread():
if self.locked_by_me():
_release = self._release_on_reentrant_locked_by_me
else:
_release = self._release_on_reentrant_locked_by_brother
else:
self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired in current-thread[{self._thread_id}]')
# 处理是否在事务提交时才释放锁
if self._release_on_transaction_commit:
logger.debug(f'I[{self.id}] release lock[{self.name}] on transaction commit ...')
transaction.on_commit(_release)
else:
_release()

View File

@ -186,6 +186,10 @@ def org_aware_func(org_arg_name):
current_org = LocalProxy(get_current_org) current_org = LocalProxy(get_current_org)
def ensure_in_real_or_default_org(): def ensure_in_real_or_default_org(func):
if not current_org or current_org.is_root(): @wraps(func)
raise ValueError('You must in a real or default org!') def wrapper(*args, **kwargs):
if not current_org or current_org.is_root():
raise ValueError('You must in a real or default org!')
return func(*args, **kwargs)
return wrapper

View File

@ -26,12 +26,6 @@ class AssetPermissionViewSet(BasePermissionViewSet):
'node_id', 'node', 'asset_id', 'hostname', 'ip' 'node_id', 'node', 'asset_id', 'hostname', 'ip'
] ]
def get_queryset(self):
queryset = super().get_queryset().prefetch_related(
"nodes", "assets", "users", "user_groups", "system_users"
)
return queryset
def filter_node(self, queryset): def filter_node(self, queryset):
node_id = self.request.query_params.get('node_id') node_id = self.request.query_params.get('node_id')
node_name = self.request.query_params.get('node') node_name = self.request.query_params.get('node')

View File

@ -14,7 +14,6 @@ from .mixin import RoleUserMixin, RoleAdminMixin
from perms.utils.asset.user_permission import ( from perms.utils.asset.user_permission import (
UserGrantedTreeBuildUtils, get_user_all_asset_perm_ids, UserGrantedTreeBuildUtils, get_user_all_asset_perm_ids,
UserGrantedNodesQueryUtils, UserGrantedAssetsQueryUtils, UserGrantedNodesQueryUtils, UserGrantedAssetsQueryUtils,
QuerySetStage,
) )
from perms.models import AssetPermission, PermNode from perms.models import AssetPermission, PermNode
from assets.models import Asset from assets.models import Asset
@ -44,10 +43,10 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
def add_favorite_resource(self, data: list, nodes_query_utils, assets_query_utils): def add_favorite_resource(self, data: list, nodes_query_utils, assets_query_utils):
favorite_node = nodes_query_utils.get_favorite_node() favorite_node = nodes_query_utils.get_favorite_node()
qs_state = QuerySetStage().annotate( favorite_assets = assets_query_utils.get_favorite_assets()
favorite_assets = favorite_assets.annotate(
parent_key=Value(favorite_node.key, output_field=CharField()) parent_key=Value(favorite_node.key, output_field=CharField())
).prefetch_related('platform') ).prefetch_related('platform')
favorite_assets = assets_query_utils.get_favorite_assets(qs_stage=qs_state, only=())
data.extend(self.serialize_nodes([favorite_node], with_asset_amount=True)) data.extend(self.serialize_nodes([favorite_node], with_asset_amount=True))
data.extend(self.serialize_assets(favorite_assets)) data.extend(self.serialize_assets(favorite_assets))
@ -59,13 +58,11 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
data.extend(self.serialize_nodes(nodes, with_asset_amount=True)) data.extend(self.serialize_nodes(nodes, with_asset_amount=True))
def add_assets(self, data: list, assets_query_utils: UserGrantedAssetsQueryUtils): def add_assets(self, data: list, assets_query_utils: UserGrantedAssetsQueryUtils):
qs_stage = QuerySetStage().annotate(parent_key=F('nodes__key')).prefetch_related('platform')
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
all_assets = assets_query_utils.get_direct_granted_nodes_assets(qs_stage=qs_stage) all_assets = assets_query_utils.get_direct_granted_nodes_assets()
else: else:
all_assets = assets_query_utils.get_all_granted_assets(qs_stage=qs_stage) all_assets = assets_query_utils.get_all_granted_assets()
all_assets = all_assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform')
data.extend(self.serialize_assets(all_assets)) data.extend(self.serialize_assets(all_assets))
@tmp_to_root_org() @tmp_to_root_org()
@ -144,8 +141,6 @@ class GrantedNodeChildrenWithAssetsAsTreeApiMixin(SerializeToTreeNodeMixin,
assets = assets_query_utils.get_node_assets(key) assets = assets_query_utils.get_node_assets(key)
assets = assets.prefetch_related('platform') assets = assets.prefetch_related('platform')
user = self.user
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True) tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, key) tree_assets = self.serialize_assets(assets, key)
return Response(data=[*tree_nodes, *tree_assets]) return Response(data=[*tree_nodes, *tree_assets])

View File

@ -45,7 +45,7 @@ class BasePermissionViewSet(OrgBulkModelViewSet):
if not self.is_query_all(): if not self.is_query_all():
queryset = queryset.filter(users=user) queryset = queryset.filter(users=user)
return queryset return queryset
groups = user.groups.all() groups = list(user.groups.all().values_list('id', flat=True))
queryset = queryset.filter( queryset = queryset.filter(
Q(users=user) | Q(user_groups__in=groups) Q(users=user) | Q(user_groups__in=groups)
).distinct() ).distinct()

View File

@ -1,4 +1,4 @@
# Generated by Django 3.1 on 2021-02-04 09:49 # Generated by Django 3.1 on 2021-02-08 07:15
import assets.models.node import assets.models.node
from django.conf import settings from django.conf import settings
@ -9,8 +9,8 @@ import django.db.models.deletion
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('assets', '0066_remove_node_assets_amount'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL), migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('assets', '0065_auto_20210121_1549'),
('perms', '0017_auto_20210104_0435'), ('perms', '0017_auto_20210104_0435'),
] ]

View File

@ -1,37 +1,17 @@
from rest_framework.pagination import LimitOffsetPagination from django.conf import settings
from rest_framework.request import Request from rest_framework.request import Request
from django.db.models import Sum
from assets.pagination import AssetPaginationBase
from perms.models import UserAssetGrantedTreeNodeRelation from perms.models import UserAssetGrantedTreeNodeRelation
from common.utils import get_logger from common.utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
class GrantedAssetPaginationBase(LimitOffsetPagination): class GrantedAssetPaginationBase(AssetPaginationBase):
def init_attrs(self, queryset, request: Request, view=None):
def paginate_queryset(self, queryset, request: Request, view=None): super().init_attrs(queryset, request, view)
self._request = request self._user = view.user
self._view = view
self._user = request.user
return super().paginate_queryset(queryset, request, view=None)
def get_count(self, queryset):
exclude_query_params = {
self.limit_query_param,
self.offset_query_param,
'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw',
'order',
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:
logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
return super().get_count(queryset)
return self.get_count_from_nodes(queryset)
def get_count_from_nodes(self, queryset):
raise NotImplementedError
class NodeGrantedAssetPagination(GrantedAssetPaginationBase): class NodeGrantedAssetPagination(GrantedAssetPaginationBase):
@ -42,11 +22,13 @@ class NodeGrantedAssetPagination(GrantedAssetPaginationBase):
return node.assets_amount return node.assets_amount
else: else:
logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} not has `pagination_node` -> {self._request.get_full_path()}') logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} not has `pagination_node` -> {self._request.get_full_path()}')
return super().get_count(queryset) return None
class AllGrantedAssetPagination(GrantedAssetPaginationBase): class AllGrantedAssetPagination(GrantedAssetPaginationBase):
def get_count_from_nodes(self, queryset): def get_count_from_nodes(self, queryset):
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return None
assets_amount = sum(UserAssetGrantedTreeNodeRelation.objects.filter( assets_amount = sum(UserAssetGrantedTreeNodeRelation.objects.filter(
user=self._user, node_parent_key='' user=self._user, node_parent_key=''
).values_list('node_assets_amount', flat=True)) ).values_list('node_assets_amount', flat=True))

View File

@ -3,9 +3,12 @@
from rest_framework import serializers from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.db.models import Prefetch
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from perms.models import AssetPermission, Action from perms.models import AssetPermission, Action
from assets.models import Asset, Node, SystemUser
from users.models import User, UserGroup
__all__ = [ __all__ = [
'AssetPermissionSerializer', 'AssetPermissionSerializer',
@ -68,5 +71,11 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('users', 'user_groups', 'assets', 'nodes', 'system_users') queryset = queryset.prefetch_related(
Prefetch('system_users', queryset=SystemUser.objects.only('id')),
Prefetch('user_groups', queryset=UserGroup.objects.only('id')),
Prefetch('users', queryset=User.objects.only('id')),
Prefetch('assets', queryset=Asset.objects.only('id')),
Prefetch('nodes', queryset=Node.objects.only('id'))
)
return queryset return queryset

View File

@ -115,8 +115,8 @@ class UnionQuerySet(QuerySet):
def __getitem__(self, item): def __getitem__(self, item):
return self.__execute()[item] return self.__execute()[item]
def __next__(self): def __iter__(self):
return next(self.__execute()) return iter(self.__execute())
@classmethod @classmethod
def test_it(cls): def test_it(cls):
@ -299,12 +299,12 @@ class UserGrantedTreeRefreshController:
cls.remove_builed_orgs_from_users(orgs_id, users_id) cls.remove_builed_orgs_from_users(orgs_id, users_id)
@classmethod @classmethod
@ensure_in_real_or_default_org
def add_need_refresh_on_nodes_assets_relate_change(cls, node_ids, asset_ids): def add_need_refresh_on_nodes_assets_relate_change(cls, node_ids, asset_ids):
""" """
1计算与这些资产有关的授权 1计算与这些资产有关的授权
2计算与这些节点以及祖先节点有关的授权 2计算与这些节点以及祖先节点有关的授权
""" """
ensure_in_real_or_default_org()
node_ids = set(node_ids) node_ids = set(node_ids)
ancestor_node_keys = set() ancestor_node_keys = set()
@ -340,8 +340,8 @@ class UserGrantedTreeRefreshController:
cls.add_need_refresh_by_asset_perm_ids(perm_ids) cls.add_need_refresh_by_asset_perm_ids(perm_ids)
@classmethod @classmethod
@ensure_in_real_or_default_org
def add_need_refresh_by_asset_perm_ids(cls, asset_perm_ids): def add_need_refresh_by_asset_perm_ids(cls, asset_perm_ids):
ensure_in_real_or_default_org()
group_ids = AssetPermission.user_groups.through.objects.filter( group_ids = AssetPermission.user_groups.through.objects.filter(
assetpermission_id__in=asset_perm_ids assetpermission_id__in=asset_perm_ids
@ -429,8 +429,8 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
return asset_ids return asset_ids
@timeit @timeit
@ensure_in_real_or_default_org
def rebuild_user_granted_tree(self): def rebuild_user_granted_tree(self):
ensure_in_real_or_default_org()
logger.info(f'Rebuild user:{self.user} tree in org:{current_org}') logger.info(f'Rebuild user:{self.user} tree in org:{current_org}')
user = self.user user = self.user
@ -618,13 +618,13 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase): class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
def get_favorite_assets(self, only=('id', )) -> QuerySet: def get_favorite_assets(self) -> QuerySet:
favorite_asset_ids = FavoriteAsset.objects.filter( favorite_asset_ids = FavoriteAsset.objects.filter(
user=self.user user=self.user
).values_list('asset_id', flat=True) ).values_list('asset_id', flat=True)
favorite_asset_ids = list(favorite_asset_ids) favorite_asset_ids = list(favorite_asset_ids)
assets = self.get_all_granted_assets() assets = self.get_all_granted_assets()
assets = assets.filter(id__in=favorite_asset_ids).only(*only) assets = assets.filter(id__in=favorite_asset_ids)
return assets return assets
def get_ungroup_assets(self) -> AssetQuerySet: def get_ungroup_assets(self) -> AssetQuerySet:
@ -670,7 +670,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
granted_status = node.get_granted_status(self.user) granted_status = node.get_granted_status(self.user)
if granted_status == NodeFrom.granted: if granted_status == NodeFrom.granted:
assets = Asset.objects.order_by().filter(nodes_id=node.id) assets = Asset.objects.order_by().filter(nodes__id=node.id)
return assets return assets
elif granted_status == NodeFrom.asset: elif granted_status == NodeFrom.asset:
return self._get_indirect_granted_node_assets(node.id) return self._get_indirect_granted_node_assets(node.id)
@ -678,7 +678,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
return Asset.objects.none() return Asset.objects.none()
def _get_indirect_granted_node_assets(self, id) -> AssetQuerySet: def _get_indirect_granted_node_assets(self, id) -> AssetQuerySet:
assets = Asset.objects.order_by().filter(nodes_id=id) & self.get_direct_granted_assets() assets = Asset.objects.order_by().filter(nodes__id=id).distinct() & self.get_direct_granted_assets()
return assets return assets
def _get_indirect_granted_node_all_assets(self, node) -> QuerySet: def _get_indirect_granted_node_all_assets(self, node) -> QuerySet: