mirror of https://github.com/jumpserver/jumpserver
140 lines
4.7 KiB
Python
140 lines
4.7 KiB
Python
# ~*~ coding: utf-8 ~*~
|
|
#
|
|
from collections import defaultdict
|
|
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit
|
|
from common.http import is_true
|
|
from common.struct import Stack
|
|
from common.db.models import output_as_string
|
|
from orgs.utils import ensure_in_real_or_default_org, current_org
|
|
|
|
from .locks import NodeTreeUpdateLock
|
|
from .models import Node, Asset
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
@NodeTreeUpdateLock()
@ensure_in_real_or_default_org
def check_node_assets_amount():
    """Recompute every node's ``assets_amount`` and persist any corrections.

    Loads all nodes and all node/asset relation rows, groups the asset ids
    by node key, lets ``NodeAssetsUtil`` aggregate them, and bulk-updates
    the nodes whose stored ``assets_amount`` differs from the computed one.
    Each mismatch is logged as an error before being corrected.
    """
    logger.info(f'Check node assets amount {current_org}')
    all_nodes = list(Node.objects.all().only('id', 'key', 'assets_amount'))
    relation_rows = list(Asset.nodes.through.objects.all().values_list('node_id', 'asset_id'))

    # node id --> node key, used to translate relation rows to node keys.
    key_by_node_id = {item.id: item.key for item in all_nodes}

    # node key --> ids of the assets attached directly to that node.
    direct_assets_by_key = defaultdict(set)
    for node_pk, asset_pk in relation_rows:
        # Relation rows pointing at a node that was not loaded are ignored.
        if node_pk in key_by_node_id:
            direct_assets_by_key[key_by_node_id[node_pk]].add(asset_pk)

    calculator = NodeAssetsUtil(all_nodes, direct_assets_by_key)
    calculator.generate()

    stale_nodes = []
    for node in all_nodes:
        assets_amount = calculator.get_assets_amount(node.key)
        if node.assets_amount == assets_amount:
            continue
        logger.error(f'Node[{node.key}] assets amount error {node.assets_amount} != {assets_amount}')
        node.assets_amount = assets_amount
        stale_nodes.append(node)
    Node.objects.bulk_update(stale_nodes, fields=('assets_amount',))
|
|
|
|
|
|
def is_query_node_all_assets(request):
    """Decide from the query string whether all of a node's assets are wanted.

    The ``show_current_asset`` parameter takes precedence: when it is
    present, the result is the negation of ``is_true`` applied to it.
    Otherwise the ``all`` parameter decides, defaulting to ``'true'``.

    :param request: object exposing a ``query_params`` mapping with ``get``
    :return: the (possibly negated) result of ``is_true``
    """
    params = request.query_params
    show_current_asset_arg = params.get('show_current_asset')
    if show_current_asset_arg is not None:
        return not is_true(show_current_asset_arg)
    return is_true(params.get('all', 'true'))
|
|
|
|
|
|
def get_node_from_request(request):
    """Look up the ``Node`` referenced by the request's query parameters.

    The reference is read from the ``node`` / ``node_id`` query parameters
    via ``dict_get_any`` (presumably the first one present wins; confirm
    against that helper). A UUID-looking value is matched against
    ``Node.id``; anything else is matched against ``Node.key``.

    :param request: object exposing a ``query_params`` mapping
    :return: whatever ``get_object_or_none`` yields for the lookup, or
        ``None`` when no node reference was supplied
    """
    node_ref = dict_get_any(request.query_params, ['node', 'node_id'])
    if not node_ref:
        return None

    lookup = {'id': node_ref} if is_uuid(node_ref) else {'key': node_ref}
    return get_object_or_none(Node, **lookup)
|
|
|
|
|
|
class NodeAssetsInfo:
    """Slot-based record holding one node's key, asset count and asset set."""

    __slots__ = ('key', 'assets_amount', 'assets')

    def __init__(self, key, assets_amount, assets):
        # Plain value holder: store the three fields exactly as given.
        self.key, self.assets_amount, self.assets = key, assets_amount, assets

    def __str__(self):
        # The node key doubles as the textual representation.
        return self.key
|
|
|
|
|
|
class NodeAssetsUtil:
    """Aggregate assets over a tree of nodes addressed by hierarchical keys.

    Node keys are colon-separated integer paths (e.g. ``"1:3:2"``). A node
    is treated as an ancestor of another when the other's key starts with
    ``"<ancestor key>:"``. After :meth:`generate` has run, each node's
    asset set holds its own assets plus those of all its descendants.
    """

    def __init__(self, nodes, nodekey_assetsid_mapper):
        """
        :param nodes: node objects; each must expose a ``key`` attribute
        :param nodekey_assetsid_mapper: mapping from node key to the set of
            ids of the assets attached directly to that node,
            e.g. ``{"key1": set(), "key2": set()}``
        """
        self.nodes = nodes
        # node key --> set(asset_id1, asset_id2)
        self.nodekey_assetsid_mapper = nodekey_assetsid_mapper
        # node key --> NodeAssetsInfo, filled in by generate()
        self.nodekey_assetsinfo_mapper = {}

    @timeit
    def generate(self):
        """Populate ``nodekey_assetsinfo_mapper`` for every node.

        Note that the sets taken from ``nodekey_assetsid_mapper`` are used
        as accumulators and mutated in place: once this method returns, an
        ancestor's set also contains the asset ids of its descendants.
        """
        # Build one info record per node, then sort by the numeric key path
        # so that every node is directly followed by its whole subtree.
        infos = []
        for node in self.nodes:
            assets = self.nodekey_assetsid_mapper.get(node.key, set())
            infos.append(NodeAssetsInfo(key=node.key, assets_amount=0, assets=assets))
        infos = sorted(infos, key=lambda info: [int(part) for part in info.key.split(':')])

        # A trailing sentinel is required: its empty key is nobody's
        # descendant, so it forces every remaining entry off the stack.
        guarder = NodeAssetsInfo(key='', assets_amount=0, assets=set())
        infos.append(guarder)

        # NOTE(review): relies on Stack.top being falsy for an empty stack;
        # confirm against common.struct.Stack.
        stack = Stack()
        for info in infos:
            # While the stack top is not an ancestor of this node, its
            # subtree is complete: pop it and finalise its asset count.
            while stack.top and not info.key.startswith(f'{stack.top.key}:'):
                pop_info = stack.pop()
                pop_info.assets_amount = len(pop_info.assets)
                self.nodekey_assetsinfo_mapper[pop_info.key] = pop_info
                if not stack.top:
                    continue
                # Fold the finished subtree's assets into the entry now on
                # top of the stack.
                stack.top.assets.update(pop_info.assets)
            stack.push(info)

    def get_assets_by_key(self, key):
        """Return the aggregated asset id set of the node with ``key``.

        :raises KeyError: if ``key`` was not processed by :meth:`generate`
        """
        info = self.nodekey_assetsinfo_mapper[key]
        # NodeAssetsInfo is a slotted object, not a mapping, so the field
        # must be read as an attribute (subscripting raised TypeError).
        return info.assets

    def get_assets_amount(self, key):
        """Return the aggregated asset count of the node with ``key``.

        :raises KeyError: if ``key`` was not processed by :meth:`generate`
        """
        info = self.nodekey_assetsinfo_mapper[key]
        return info.assets_amount

    @classmethod
    def test_it(cls):
        """Build a util from every node and relation in the database, run
        :meth:`generate` and return the util (manual debugging helper).
        """
        from assets.models import Node, Asset

        nodes = list(Node.objects.all())
        nodes_assets = (
            Asset.nodes.through.objects.all()
            .annotate(aid=output_as_string('asset_id'))
            .values_list('node__key', 'aid')
        )

        mapping = defaultdict(set)
        for key, asset_id in nodes_assets:
            mapping[key].add(asset_id)

        util = cls(nodes, mapping)
        util.generate()
        return util
|