diff --git a/apps/assets/signals_handler.py b/apps/assets/signals_handler.py index 594da38b7..3b23de31b 100644 --- a/apps/assets/signals_handler.py +++ b/apps/assets/signals_handler.py @@ -10,6 +10,7 @@ from django.dispatch import receiver from common.utils import get_logger, timeit from common.decorator import on_transaction_commit from .models import Asset, SystemUser, Node, AuthBook +from .utils import TreeService from .tasks import ( update_assets_hardware_info_util, test_asset_connectivity_util, @@ -131,16 +132,22 @@ def on_asset_nodes_add(sender, instance=None, action='', model=None, pk_set=None if action != "post_add": return logger.debug("Assets node add signal recv: {}".format(action)) - queryset = model.objects.filter(pk__in=pk_set).values_list('id', flat=True) + queryset = model.objects.filter(pk__in=pk_set).values_list('key', flat=True) if model == Node: nodes = queryset assets = [instance] else: nodes = [instance] assets = queryset - # 节点资产发生变化时,将资产关联到节点关联的系统用户, 只关注新增的 + # 节点资产发生变化时,将资产关联到节点及祖先节点关联的系统用户, 只关注新增的 + nodes_ancestors_keys = set() + node_tree = TreeService.new() + for node in nodes: + ancestors_keys = node_tree.ancestors_ids(nid=node) + nodes_ancestors_keys.update(ancestors_keys) + system_users = SystemUser.objects.filter(nodes__key__in=nodes_ancestors_keys) + system_users_assets = defaultdict(set) - system_users = SystemUser.objects.filter(nodes__in=nodes) for system_user in system_users: system_users_assets[system_user].update(set(assets)) for system_user, _assets in system_users_assets.items(): diff --git a/apps/assets/utils.py b/apps/assets/utils.py index e0b316ad7..eaf3d502a 100644 --- a/apps/assets/utils.py +++ b/apps/assets/utils.py @@ -84,11 +84,15 @@ class TreeService(Tree): children_ids = self.all_children_ids(nid, with_self=with_self) return [self.get_node(i, deep=deep) for i in children_ids] - def ancestors(self, nid, with_self=False, deep=False): + def ancestors_ids(self, nid, with_self=True): ancestor_ids = list(self.rsearch(nid)) ancestor_ids.pop() if not with_self: ancestor_ids.pop(0) + return ancestor_ids + + def ancestors(self, nid, with_self=False, deep=False): + ancestor_ids = self.ancestors_ids(nid, with_self=with_self) return [self.get_node(i, deep=deep) for i in ancestor_ids] def get_node_full_tag(self, nid):