Merge pull request #9494 from jumpserver/pr@dev@perf_api_bulk_add

perf: 优化并发处理
pull/9507/head
老广 2023-02-10 15:55:12 +08:00 committed by GitHub
commit eebd6c30de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 357 additions and 420 deletions

View File

@ -1,13 +1,13 @@
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from assets.models import Asset
from accounts.const import SecretType, Source from accounts.const import SecretType, Source
from accounts.models import Account, AccountTemplate from accounts.models import Account, AccountTemplate
from accounts.tasks import push_accounts_to_assets from accounts.tasks import push_accounts_to_assets
from assets.const import Category, AllTypes from assets.const import Category, AllTypes
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField from assets.models import Asset
from common.serializers import SecretReadableMixin, BulkModelSerializer from common.serializers import SecretReadableMixin, BulkModelSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from .base import BaseAccountSerializer from .base import BaseAccountSerializer

View File

@ -1,4 +1,5 @@
from django.db.models.signals import pre_save, post_save from django.db.models.signals import post_save
from django.db.models.signals import pre_save
from django.dispatch import receiver from django.dispatch import receiver
from assets.models import Asset from assets.models import Asset

View File

@ -4,9 +4,9 @@ from assets.tasks.common import generate_automation_execution_data
from common.const.choices import Trigger from common.const.choices import Trigger
def automation_execute_start(task_name, tp, child_snapshot=None): def automation_execute_start(task_name, tp, task_snapshot=None):
from accounts.models import AutomationExecution from accounts.models import AutomationExecution
data = generate_automation_execution_data(task_name, tp, child_snapshot) data = generate_automation_execution_data(task_name, tp, task_snapshot)
while True: while True:
try: try:

View File

@ -1,13 +1,13 @@
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
from celery import shared_task from celery import shared_task
from django.utils.translation import gettext_noop
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext_noop
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
from assets.models import Node from assets.models import Node
from common.utils import get_logger from common.utils import get_logger
from orgs.utils import org_aware_func from orgs.utils import org_aware_func
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
__all__ = ['gather_asset_accounts'] __all__ = ['gather_asset_accounts']
logger = get_logger(__name__) logger = get_logger(__name__)
@ -18,11 +18,11 @@ def gather_asset_accounts_util(nodes, task_name):
from accounts.models import GatherAccountsAutomation from accounts.models import GatherAccountsAutomation
task_name = GatherAccountsAutomation.generate_unique_name(task_name) task_name = GatherAccountsAutomation.generate_unique_name(task_name)
child_snapshot = { task_snapshot = {
'nodes': [str(node.id) for node in nodes], 'nodes': [str(node.id) for node in nodes],
} }
tp = AutomationTypes.verify_account tp = AutomationTypes.verify_account
automation_execute_start(task_name, tp, child_snapshot) automation_execute_start(task_name, tp, task_snapshot)
@shared_task(queue="ansible", verbose_name=_('Gather asset accounts')) @shared_task(queue="ansible", verbose_name=_('Gather asset accounts'))

View File

@ -1,10 +1,10 @@
from celery import shared_task from celery import shared_task
from django.utils.translation import gettext_noop, ugettext_lazy as _ from django.utils.translation import gettext_noop, ugettext_lazy as _
from common.utils import get_logger
from orgs.utils import org_aware_func
from accounts.const import AutomationTypes from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start from accounts.tasks.common import automation_execute_start
from common.utils import get_logger
from orgs.utils import org_aware_func
logger = get_logger(__file__) logger = get_logger(__file__)
__all__ = [ __all__ = [
@ -13,14 +13,14 @@ __all__ = [
def push_util(account, assets, task_name): def push_util(account, assets, task_name):
child_snapshot = { task_snapshot = {
'secret': account.secret, 'secret': account.secret,
'secret_type': account.secret_type, 'secret_type': account.secret_type,
'accounts': [account.username], 'accounts': [account.username],
'assets': [str(asset.id) for asset in assets], 'assets': [str(asset.id) for asset in assets],
} }
tp = AutomationTypes.push_account tp = AutomationTypes.push_account
automation_execute_start(task_name, tp, child_snapshot) automation_execute_start(task_name, tp, task_snapshot)
@org_aware_func("assets") @org_aware_func("assets")

View File

@ -2,10 +2,10 @@ from celery import shared_task
from django.utils.translation import gettext_noop from django.utils.translation import gettext_noop
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from common.utils import get_logger
from assets.const import GATEWAY_NAME
from accounts.const import AutomationTypes from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start from accounts.tasks.common import automation_execute_start
from assets.const import GATEWAY_NAME
from common.utils import get_logger
from orgs.utils import org_aware_func from orgs.utils import org_aware_func
logger = get_logger(__name__) logger = get_logger(__name__)
@ -18,11 +18,11 @@ def verify_connectivity_util(assets, tp, accounts, task_name):
if not assets or not accounts: if not assets or not accounts:
return return
account_usernames = list(accounts.values_list('username', flat=True)) account_usernames = list(accounts.values_list('username', flat=True))
child_snapshot = { task_snapshot = {
'accounts': account_usernames, 'accounts': account_usernames,
'assets': [str(asset.id) for asset in assets], 'assets': [str(asset.id) for asset in assets],
} }
automation_execute_start(task_name, tp, child_snapshot) automation_execute_start(task_name, tp, task_snapshot)
@org_aware_func("assets") @org_aware_func("assets")

View File

@ -256,8 +256,6 @@ class FamilyMixin:
class NodeAllAssetsMappingMixin: class NodeAllAssetsMappingMixin:
# Use a new plan
# { org_id: { node_key: [ asset1_id, asset2_id ] } } # { org_id: { node_key: [ asset1_id, asset2_id ] } }
orgid_nodekey_assetsid_mapping = defaultdict(dict) orgid_nodekey_assetsid_mapping = defaultdict(dict)
locks_for_get_mapping_from_cache = defaultdict(threading.Lock) locks_for_get_mapping_from_cache = defaultdict(threading.Lock)
@ -273,20 +271,7 @@ class NodeAllAssetsMappingMixin:
if _mapping: if _mapping:
return _mapping return _mapping
logger.debug(f'Get node asset mapping from memory failed, acquire thread lock: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
with cls.get_lock(org_id): with cls.get_lock(org_id):
logger.debug(f'Acquired thread lock ok. check if mapping is in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
_mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id)
if _mapping:
logger.debug(f'Mapping is already in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return _mapping
_mapping = cls.get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(org_id) _mapping = cls.get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(org_id)
cls.set_node_all_asset_ids_mapping_to_memory(org_id, mapping=_mapping) cls.set_node_all_asset_ids_mapping_to_memory(org_id, mapping=_mapping)
return _mapping return _mapping
@ -302,18 +287,18 @@ class NodeAllAssetsMappingMixin:
cls.orgid_nodekey_assetsid_mapping[org_id] = mapping cls.orgid_nodekey_assetsid_mapping[org_id] = mapping
@classmethod @classmethod
def expire_node_all_asset_ids_mapping_from_memory(cls, org_id): def expire_node_all_asset_ids_memory_mapping(cls, org_id):
org_id = str(org_id) org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None) cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
@classmethod @classmethod
def expire_all_orgs_node_all_asset_ids_mapping_from_memory(cls): def expire_all_orgs_node_all_asset_ids_memory_mapping(cls):
orgs = Organization.objects.all() orgs = Organization.objects.all()
org_ids = [str(org.id) for org in orgs] org_ids = [str(org.id) for org in orgs]
org_ids.append(Organization.ROOT_ID) org_ids.append(Organization.ROOT_ID)
for id in org_ids: for i in org_ids:
cls.expire_node_all_asset_ids_mapping_from_memory(id) cls.expire_node_all_asset_ids_memory_mapping(i)
# get order: from memory -> (from cache -> to generate) # get order: from memory -> (from cache -> to generate)
@classmethod @classmethod
@ -332,25 +317,18 @@ class NodeAllAssetsMappingMixin:
return _mapping return _mapping
_mapping = cls.generate_node_all_asset_ids_mapping(org_id) _mapping = cls.generate_node_all_asset_ids_mapping(org_id)
cls.set_node_all_asset_ids_mapping_to_cache(org_id=org_id, mapping=_mapping) cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, mapping, timeout=None)
return _mapping return _mapping
@classmethod @classmethod
def get_node_all_asset_ids_mapping_from_cache(cls, org_id): def get_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id) cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
mapping = cache.get(cache_key) mapping = cache.get(cache_key)
logger.info(f'Get node asset mapping from cache {bool(mapping)}: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return mapping return mapping
@classmethod @classmethod
def set_node_all_asset_ids_mapping_to_cache(cls, org_id, mapping): def expire_node_all_asset_ids_cache_mapping(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, mapping, timeout=None)
@classmethod
def expire_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id) cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.delete(cache_key) cache.delete(cache_key)
@ -411,6 +389,14 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key) q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key)
return Asset.objects.filter(q).distinct() return Asset.objects.filter(q).distinct()
def get_assets_amount(self):
q = Q(node__key__startswith=f'{self.key}:') | Q(node__key=self.key)
return self.assets.through.objects.filter(q).count()
def get_assets_account_by_children(self):
children = self.get_all_children().values_list()
return self.assets.through.objects.filter(node_id__in=children).count()
@classmethod @classmethod
def get_node_all_assets_by_key_v2(cls, key): def get_node_all_assets_by_key_v2(cls, key):
# 最初的写法是: # 最初的写法是:

View File

@ -130,7 +130,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
] ]
read_only_fields = [ read_only_fields = [
'category', 'type', 'connectivity', 'category', 'type', 'connectivity',
'date_verified', 'created_by', 'date_created' 'date_verified', 'created_by', 'date_created',
] ]
fields = fields_small + fields_fk + fields_m2m + read_only_fields fields = fields_small + fields_fk + fields_m2m + read_only_fields
extra_kwargs = { extra_kwargs = {
@ -228,6 +228,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
node_id = request.query_params.get('node_id') node_id = request.query_params.get('node_id')
if not node_id: if not node_id:
return [] return []
nodes = Node.objects.filter(id=node_id)
return nodes
def is_valid(self, raise_exception=False): def is_valid(self, raise_exception=False):
self._set_protocols_default() self._set_protocols_default()

View File

@ -1,15 +1,15 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models.signals import ( from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete, pre_save m2m_changed, pre_delete, post_delete, pre_save, post_save
) )
from django.dispatch import receiver from django.dispatch import receiver
from django.utils.translation import gettext_noop from django.utils.translation import gettext_noop
from assets.models import Asset, Node, Cloud, Device, Host, Web, Database from assets.models import Asset, Node, Host, Database, Device, Web, Cloud
from assets.tasks import test_assets_connectivity_task from assets.tasks import test_assets_connectivity_task, gather_assets_facts_task
from common.const.signals import POST_ADD, POST_REMOVE, PRE_REMOVE from common.const.signals import POST_REMOVE, PRE_REMOVE
from common.decorators import on_transaction_commit, merge_delay_run from common.decorators import on_transaction_commit, merge_delay_run, key_by_org
from common.utils import get_logger from common.utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
@ -20,15 +20,33 @@ def on_node_pre_save(sender, instance: Node, **kwargs):
instance.parent_key = instance.compute_parent_key() instance.parent_key = instance.compute_parent_key()
@merge_delay_run(ttl=10) @merge_delay_run(ttl=5, key=key_by_org)
def test_assets_connectivity_handler(*assets): def test_assets_connectivity_handler(*assets):
task_name = gettext_noop("Test assets connectivity ") task_name = gettext_noop("Test assets connectivity ")
test_assets_connectivity_task.delay(assets, task_name) test_assets_connectivity_task.delay(assets, task_name)
@merge_delay_run(ttl=10) @merge_delay_run(ttl=5, key=key_by_org)
def gather_assets_facts_handler(*assets): def gather_assets_facts_handler(*assets):
pass if not assets:
logger.info("No assets to update hardware info")
return
name = gettext_noop("Gather asset hardware info")
gather_assets_facts_task.delay(assets=assets, task_name=name)
@merge_delay_run(ttl=5, key=key_by_org)
def ensure_asset_has_node(*assets):
asset_ids = [asset.id for asset in assets]
has_ids = Asset.nodes.through.objects \
.filter(asset_id__in=asset_ids) \
.values_list('asset_id', flat=True)
need_ids = set(asset_ids) - set(has_ids)
if not need_ids:
return
org_root = Node.org_root()
org_root.assets.add(*need_ids)
@receiver(post_save, sender=Asset) @receiver(post_save, sender=Asset)
@ -42,38 +60,16 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
return return
logger.info("Asset create signal recv: {}".format(instance)) logger.info("Asset create signal recv: {}".format(instance))
ensure_asset_has_node(instance)
# 获取资产硬件信息 # 获取资产硬件信息
test_assets_connectivity_handler([instance]) auto_info = instance.auto_info
gather_assets_facts_handler([instance]) if auto_info.get('ping_enabled'):
logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
# 确保资产存在一个节点 test_assets_connectivity_handler(instance)
has_node = instance.nodes.all().exists() if auto_info.get('gather_facts_enabled'):
if not has_node: logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
instance.nodes.add(Node.org_root()) gather_assets_facts_handler(instance)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
"""
本操作共访问 4 次数据库
当资产的节点发生变化时或者 当节点的资产关系发生变化时
节点下新增的资产添加到节点关联的系统用户中
"""
if action != POST_ADD:
return
logger.debug("Assets node add signal recv: {}".format(action))
if reverse:
nodes = [instance.key]
asset_ids = pk_set
else:
nodes = Node.objects.filter(pk__in=pk_set).values_list('key', flat=True)
asset_ids = [instance.id]
# 节点资产发生变化时,将资产关联到节点及祖先节点关联的系统用户, 只关注新增的
nodes_ancestors_keys = set()
for node in nodes:
nodes_ancestors_keys.update(Node.get_node_ancestor_keys(node, with_self=True))
RELATED_NODE_IDS = '_related_node_ids' RELATED_NODE_IDS = '_related_node_ids'
@ -82,19 +78,19 @@ RELATED_NODE_IDS = '_related_node_ids'
@receiver(pre_delete, sender=Asset) @receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, using, **kwargs): def on_asset_delete(instance: Asset, using, **kwargs):
logger.debug("Asset pre delete signal recv: {}".format(instance)) logger.debug("Asset pre delete signal recv: {}".format(instance))
node_ids = set(Node.objects.filter( node_ids = Node.objects.filter(assets=instance) \
assets=instance .distinct().values_list('id', flat=True)
).distinct().values_list('id', flat=True))
setattr(instance, RELATED_NODE_IDS, node_ids) setattr(instance, RELATED_NODE_IDS, node_ids)
m2m_changed.send( m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False, sender=Asset.nodes.through, instance=instance,
model=Node, pk_set=node_ids, using=using, action=PRE_REMOVE reverse=False, model=Node, pk_set=node_ids,
using=using, action=PRE_REMOVE
) )
@receiver(post_delete, sender=Asset) @receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, using, **kwargs): def on_asset_post_delete(instance: Asset, using, **kwargs):
logger.debug("Asset delete signal recv: {}".format(instance)) logger.debug("Asset post delete signal recv: {}".format(instance))
node_ids = getattr(instance, RELATED_NODE_IDS, None) node_ids = getattr(instance, RELATED_NODE_IDS, None)
if node_ids: if node_ids:
m2m_changed.send( m2m_changed.send(

View File

@ -1,22 +1,21 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from operator import add, sub from operator import add, sub
from django.db.models import Q, F
from django.db.models.signals import m2m_changed
from django.dispatch import receiver from django.dispatch import receiver
from django.db.models.signals import (
m2m_changed
)
from orgs.utils import ensure_in_real_or_default_org, tmp_to_org from assets.models import Asset, Node
from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR
from common.decorators import on_transaction_commit, merge_delay_run
from common.utils import get_logger from common.utils import get_logger
from assets.models import Asset, Node, compute_parent_key from orgs.utils import tmp_to_org
from assets.locks import NodeTreeUpdateLock from ..tasks import check_node_assets_amount_task
logger = get_logger(__file__) logger = get_logger(__file__)
@on_transaction_commit
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs): def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set` # 不允许 `pre_clear` ,因为该信号没有 `pk_set`
@ -25,136 +24,29 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
if action in refused: if action in refused:
raise ValueError raise ValueError
mapper = { logger.debug('Recv asset nodes change signal, recompute node assets amount')
PRE_ADD: add, mapper = {PRE_ADD: add, POST_REMOVE: sub}
POST_REMOVE: sub
}
if action not in mapper: if action not in mapper:
return return
operator = mapper[action]
with tmp_to_org(instance.org): with tmp_to_org(instance.org):
if reverse: if reverse:
node: Node = instance node_ids = [instance.id]
asset_pk_set = set(pk_set)
NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator)
else: else:
asset_pk = instance.id node_ids = pk_set
# 与资产直接关联的节点 update_nodes_assets_amount(*node_ids)
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator)
class NodeAssetsAmountUtils: @merge_delay_run(ttl=5)
def update_nodes_assets_amount(*node_ids):
nodes = list(Node.objects.filter(id__in=node_ids))
logger.info('Update nodes assets amount: {} nodes'.format(len(node_ids)))
@classmethod if len(node_ids) > 100:
def _remove_ancestor_keys(cls, ancestor_key, tree_set): check_node_assets_amount_task.delay()
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环 return
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
@classmethod for node in nodes:
def _is_asset_exists_in_node(cls, asset_pk, node_key): node.assets_amount = node.get_assets_amount()
exists = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).filter(id=asset_pk).exists()
return exists
@classmethod Node.objects.bulk_update(nodes, ['assets_amount'])
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_nodes_asset_amount(cls, node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = cls._is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
cls._remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = cls._is_asset_exists_in_node(asset_pk, parent_key)
if exists:
cls._remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))

View File

@ -2,42 +2,35 @@
# #
from django.db.models.signals import ( from django.db.models.signals import (
m2m_changed, post_save, post_delete post_save, post_delete, m2m_changed
) )
from django.dispatch import receiver from django.dispatch import receiver
from django.utils.functional import LazyObject from django.utils.functional import lazy
from assets.models import Asset, Node from assets.models import Node, Asset
from common.decorators import merge_delay_run
from common.signals import django_ready from common.signals import django_ready
from common.utils import get_logger from common.utils import get_logger
from common.utils.connection import RedisPubSub from common.utils.connection import RedisPubSub
from orgs.models import Organization from orgs.models import Organization
logger = get_logger(__file__) logger = get_logger(__name__)
# clear node assets mapping for memory # clear node assets mapping for memory
# ------------------------------------ # ------------------------------------
node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)()
class NodeAssetsMappingForMemoryPubSub(LazyObject): @merge_delay_run(ttl=5)
def _setup(self): def expire_node_assets_mapping(*org_ids):
self._wrapped = RedisPubSub('fm.node_all_asset_ids_memory_mapping')
node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub()
def expire_node_assets_mapping_for_memory(org_id):
# 所有进程清除(自己的 memory 数据) # 所有进程清除(自己的 memory 数据)
org_id = str(org_id)
root_org_id = Organization.ROOT_ID root_org_id = Organization.ROOT_ID
Node.expire_node_all_asset_ids_cache_mapping(root_org_id)
for org_id in set(org_ids):
org_id = str(org_id)
# 当前进程清除(cache 数据) # 当前进程清除(cache 数据)
Node.expire_node_all_asset_ids_mapping_from_cache(org_id) Node.expire_node_all_asset_ids_cache_mapping(org_id)
Node.expire_node_all_asset_ids_mapping_from_cache(root_org_id) node_assets_mapping_pub_sub.publish(org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
@receiver(post_save, sender=Node) @receiver(post_save, sender=Node)
@ -50,17 +43,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
need_expire = False need_expire = False
if need_expire: if need_expire:
expire_node_assets_mapping_for_memory(instance.org_id) expire_node_assets_mapping(instance.org_id)
@receiver(post_delete, sender=Node) @receiver(post_delete, sender=Node)
def on_node_post_delete(sender, instance, **kwargs): def on_node_post_delete(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id) expire_node_assets_mapping(instance.org_id)
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs): def on_node_asset_change(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id) logger.debug("Recv asset nodes changed signal, expire memery node asset mapping")
expire_node_assets_mapping(instance.org_id)
@receiver(django_ready) @receiver(django_ready)
@ -69,7 +63,7 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
def handle_node_relation_change(org_id): def handle_node_relation_change(org_id):
root_org_id = Organization.ROOT_ID root_org_id = Organization.ROOT_ID
Node.expire_node_all_asset_ids_mapping_from_memory(org_id) Node.expire_node_all_asset_ids_memory_mapping(org_id)
Node.expire_node_all_asset_ids_mapping_from_memory(root_org_id) Node.expire_node_all_asset_ids_memory_mapping(root_org_id)
node_assets_mapping_for_memory_pub_sub.subscribe(handle_node_relation_change) node_assets_mapping_pub_sub.subscribe(handle_node_relation_change)

View File

@ -8,8 +8,8 @@ from common.const.choices import Trigger
from orgs.utils import current_org from orgs.utils import current_org
def generate_automation_execution_data(task_name, tp, child_snapshot=None): def generate_automation_execution_data(task_name, tp, task_snapshot=None):
child_snapshot = child_snapshot or {} task_snapshot = task_snapshot or {}
from assets.models import BaseAutomation from assets.models import BaseAutomation
try: try:
eid = current_task.request.id eid = current_task.request.id
@ -25,13 +25,13 @@ def generate_automation_execution_data(task_name, tp, child_snapshot=None):
automation_instance = BaseAutomation() automation_instance = BaseAutomation()
snapshot = automation_instance.to_attr_json() snapshot = automation_instance.to_attr_json()
snapshot.update(data) snapshot.update(data)
snapshot.update(child_snapshot) snapshot.update(task_snapshot)
return {'id': eid, 'snapshot': snapshot} return {'id': eid, 'snapshot': snapshot}
def quickstart_automation(task_name, tp, child_snapshot=None): def quickstart_automation(task_name, tp, task_snapshot=None):
from assets.models import AutomationExecution from assets.models import AutomationExecution
data = generate_automation_execution_data(task_name, tp, child_snapshot) data = generate_automation_execution_data(task_name, tp, task_snapshot)
while True: while True:
try: try:

View File

@ -1,65 +1,55 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from itertools import chain
from celery import shared_task from celery import shared_task
from django.utils.translation import gettext_noop, gettext_lazy as _ from django.utils.translation import gettext_noop, gettext_lazy as _
from assets.const import AutomationTypes from assets.const import AutomationTypes
from common.utils import get_logger from common.utils import get_logger
from orgs.utils import org_aware_func from orgs.utils import tmp_to_org
from .common import quickstart_automation from .common import quickstart_automation
logger = get_logger(__file__) logger = get_logger(__file__)
__all__ = [ __all__ = [
'update_assets_fact_util', 'gather_assets_facts_task',
'update_node_assets_hardware_info_manual', 'update_node_assets_hardware_info_manual',
'update_assets_hardware_info_manual', 'update_assets_hardware_info_manual',
] ]
def update_fact_util(assets=None, nodes=None, task_name=None): @shared_task(queue="ansible", verbose_name=_('Gather assets facts'))
def gather_assets_facts_task(assets=None, nodes=None, task_name=None):
from assets.models import GatherFactsAutomation from assets.models import GatherFactsAutomation
if task_name is None: if task_name is None:
task_name = gettext_noop("Update some assets hardware info. ") task_name = gettext_noop("Gather assets facts")
task_name = GatherFactsAutomation.generate_unique_name(task_name) task_name = GatherFactsAutomation.generate_unique_name(task_name)
nodes = nodes or [] nodes = nodes or []
assets = assets or [] assets = assets or []
child_snapshot = { resources = chain(assets, nodes)
if not resources:
raise ValueError("nodes or assets must be given")
org_id = list(resources)[0].org_id
task_snapshot = {
'assets': [str(asset.id) for asset in assets], 'assets': [str(asset.id) for asset in assets],
'nodes': [str(node.id) for node in nodes], 'nodes': [str(node.id) for node in nodes],
} }
tp = AutomationTypes.gather_facts tp = AutomationTypes.gather_facts
quickstart_automation(task_name, tp, child_snapshot)
with tmp_to_org(org_id):
quickstart_automation(task_name, tp, task_snapshot)
@org_aware_func('assets')
def update_assets_fact_util(assets=None, task_name=None):
if assets is None:
logger.info("No assets to update hardware info")
return
update_fact_util(assets=assets, task_name=task_name)
@org_aware_func('nodes')
def update_nodes_fact_util(nodes=None, task_name=None):
if nodes is None:
logger.info("No nodes to update hardware info")
return
update_fact_util(nodes=nodes, task_name=task_name)
@shared_task(queue="ansible", verbose_name=_('Manually update the hardware information of assets'))
def update_assets_hardware_info_manual(asset_ids): def update_assets_hardware_info_manual(asset_ids):
from assets.models import Asset from assets.models import Asset
assets = Asset.objects.filter(id__in=asset_ids) assets = Asset.objects.filter(id__in=asset_ids)
task_name = gettext_noop("Update assets hardware info: ") task_name = gettext_noop("Update assets hardware info: ")
update_assets_fact_util(assets=assets, task_name=task_name) gather_assets_facts_task.delay(assets=assets, task_name=task_name)
@shared_task(queue="ansible", verbose_name=_('Manually update the hardware information of assets under a node'))
def update_node_assets_hardware_info_manual(node_id): def update_node_assets_hardware_info_manual(node_id):
from assets.models import Node from assets.models import Node
node = Node.objects.get(id=node_id) node = Node.objects.get(id=node_id)
task_name = gettext_noop("Update node asset hardware information: ") task_name = gettext_noop("Update node asset hardware information: ")
update_nodes_fact_util(nodes=[node], task_name=task_name) gather_assets_facts_task.delay(nodes=[node], task_name=task_name)

View File

@ -24,8 +24,8 @@ def test_assets_connectivity_task(assets, task_name=None):
task_name = gettext_noop("Test assets connectivity ") task_name = gettext_noop("Test assets connectivity ")
task_name = PingAutomation.generate_unique_name(task_name) task_name = PingAutomation.generate_unique_name(task_name)
child_snapshot = {'assets': [str(asset.id) for asset in assets]} task_snapshot = {'assets': [str(asset.id) for asset in assets]}
quickstart_automation(task_name, AutomationTypes.ping, child_snapshot) quickstart_automation(task_name, AutomationTypes.ping, task_snapshot)
def test_assets_connectivity_manual(asset_ids): def test_assets_connectivity_manual(asset_ids):

View File

@ -1,12 +1,12 @@
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
# #
from collections import defaultdict from collections import defaultdict
from common.db.models import output_as_string
from common.struct import Stack
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit
from common.utils.http import is_true from common.utils.http import is_true
from common.struct import Stack
from common.db.models import output_as_string
from orgs.utils import ensure_in_real_or_default_org, current_org from orgs.utils import ensure_in_real_or_default_org, current_org
from ..locks import NodeTreeUpdateLock from ..locks import NodeTreeUpdateLock
from ..models import Node, Asset from ..models import Node, Asset
@ -25,11 +25,11 @@ def check_node_assets_amount():
for node in nodes: for node in nodes:
nodeid_nodekey_mapper[node.id] = node.key nodeid_nodekey_mapper[node.id] = node.key
for nodeid, assetid in nodeid_assetid_pairs: for node_id, asset_id in nodeid_assetid_pairs:
if nodeid not in nodeid_nodekey_mapper: if node_id not in nodeid_nodekey_mapper:
continue continue
nodekey = nodeid_nodekey_mapper[nodeid] node_key = nodeid_nodekey_mapper[node_id]
nodekey_assetids_mapper[nodekey].add(assetid) nodekey_assetids_mapper[node_key].add(asset_id)
util = NodeAssetsUtil(nodes, nodekey_assetids_mapper) util = NodeAssetsUtil(nodes, nodekey_assetids_mapper)
util.generate() util.generate()

View File

View File

@ -1,14 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.utils.translation import ugettext_lazy as _ from .choices import *
from .common import *
create_success_msg = _("%(name)s was created successfully") from .crontab import *
update_success_msg = _("%(name)s was updated successfully") from .http import *
FILE_END_GUARD = ">>> Content End <<<" from .signals import *
celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable
# https://docs.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties
LDAP_AD_ACCOUNT_DISABLE = 2

View File

@ -0,0 +1,11 @@
from django.utils.translation import ugettext_lazy as _
create_success_msg = _("%(name)s was created successfully")
update_success_msg = _("%(name)s was updated successfully")
FILE_END_GUARD = ">>> Content End <<<"
celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable
# https://docs.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties
LDAP_AD_ACCOUNT_DISABLE = 2

View File

@ -1,7 +1,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import asyncio
import functools import functools
import inspect import inspect
import threading
import time import time
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -9,6 +11,8 @@ from concurrent.futures import ThreadPoolExecutor
from django.core.cache import cache from django.core.cache import cache
from django.db import transaction from django.db import transaction
from .utils import logger
def on_transaction_commit(func): def on_transaction_commit(func):
""" """
@ -34,54 +38,124 @@ class Singleton(object):
return self._instance[self._cls] return self._instance[self._cls]
def _run_func_if_is_last(ttl, func, *args, **kwargs): def default_suffix_key(*args, **kwargs):
ix = uuid.uuid4().__str__() return 'default'
key = f'DELAY_RUN_{func.__name__}'
cache.set(key, ix, ttl)
def key_by_org(*args, **kwargs):
return args[0].org_id
def _run_func_if_is_last(ttl, suffix_key, org, func, *args, **kwargs):
from orgs.utils import set_current_org
try:
set_current_org(org)
uid = uuid.uuid4().__str__()
suffix_key_func = suffix_key if suffix_key else default_suffix_key
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
key = f'DELAY_RUN_{func_name}_{key_suffix}'
cache.set(key, uid, ttl)
st = (ttl - 2 > 1) and ttl - 2 or 2 st = (ttl - 2 > 1) and ttl - 2 or 2
time.sleep(st) time.sleep(st)
got = cache.get(key, None) ret = cache.get(key, None)
if ix == got: if uid == ret:
func(*args, **kwargs) func(*args, **kwargs)
except Exception as e:
logger.error('delay run error: %s' % e)
executor = ThreadPoolExecutor(10) class LoopThread(threading.Thread):
def __init__(self, loop, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loop = loop
def run(self) -> None:
asyncio.set_event_loop(loop)
self.loop.run_forever()
print('loop stopped')
def delay_run(ttl=5): loop = asyncio.get_event_loop()
loop_thread = LoopThread(loop)
loop_thread.daemon = True
loop_thread.start()
executor = ThreadPoolExecutor(max_workers=5, thread_name_prefix='debouncer')
class Debouncer(object):
def __init__(self, callback, check, delay, *args, **kwargs):
self.callback = callback
self.check = check
self.delay = delay
async def __call__(self, *args, **kwargs):
await asyncio.sleep(self.delay)
ok = await self._check(*args, **kwargs)
if ok:
await loop.run_in_executor(executor, self.callback, *args)
async def _check(self, *args, **kwargs):
if asyncio.iscoroutinefunction(self.check):
return await self.check(*args, **kwargs)
return await loop.run_in_executor(executor, self.check)
def _run_func_with_org(org, func, *args, **kwargs):
from orgs.utils import set_current_org
try:
set_current_org(org)
func(*args, **kwargs)
except Exception as e:
logger.error('delay run error: %s' % e)
def delay_run(ttl=5, key=None):
"""
延迟执行函数, ttl 秒内, 只执行最后一次
:param ttl:
:param key: 是否合并参数, 一个 callback
:return:
"""
def inner(func): def inner(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
executor.submit(_run_func_if_is_last, ttl, func, *args, **kwargs) from orgs.utils import get_current_org
org = get_current_org()
suffix_key_func = key if key else default_suffix_key
uid = uuid.uuid4().__str__()
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
cache_key = f'DELAY_RUN_{func_name}_{key_suffix}'
# 延迟两倍时间,防止缓存过期,导致校验失败
cache.set(cache_key, uid, ttl * 2)
def _check_func(key_id, key_value):
ret = cache.get(key_id, None)
return key_value == ret
check_func_partial = functools.partial(_check_func, cache_key, uid)
run_func_partial = functools.partial(_run_func_with_org, org, func)
asyncio.run_coroutine_threadsafe(
Debouncer(run_func_partial, check_func_partial, ttl)(*args, **kwargs),
loop=loop
)
return wrapper return wrapper
return inner return inner
def _merge_run(ttl, func, *args, **kwargs): def merge_delay_run(ttl, key=None):
if not args or not isinstance(args[0], (list, tuple)):
raise ValueError('args[0] must be list or tuple')
key = f'DELAY_MERGE_RUN_{func.__name__}'
ix = uuid.uuid4().__str__()
value = cache.get(key, [])
value.extend(args[0])
st = (ttl - 2 > 1) and ttl - 2 or 2
time.sleep(st)
got = cache.get(key, None)
if ix == got:
func(*args, **kwargs)
def merge_delay_run(ttl):
""" """
合并 func 参数延迟执行, ttl 秒内, 只执行最后一次 合并 func 参数延迟执行, ttl 秒内, 只执行最后一次
func 参数必须是 *args func 参数必须是 *args
:param ttl: :param ttl:
:param key: 是否合并参数, 一个 callback
:return: :return:
""" """
@ -93,42 +167,50 @@ def merge_delay_run(ttl):
if not str(param).startswith('*'): if not str(param).startswith('*'):
raise ValueError('func args must be startswith *: %s' % func.__name__) raise ValueError('func args must be startswith *: %s' % func.__name__)
suffix_key_func = key if key else default_suffix_key
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args): def wrapper(*args):
key = f'DELAY_MERGE_RUN_{func.__name__}' key_suffix = suffix_key_func(*args)
values = cache.get(key, []) func_name = f'{func.__module__}_{func.__name__}'
cache_key = f'DELAY_MERGE_RUN_{func_name}_{key_suffix}'
values = cache.get(cache_key, [])
new_arg = [*values, *args] new_arg = [*values, *args]
cache.set(key, new_arg, ttl) cache.set(cache_key, new_arg, ttl)
return delay_run(ttl)(func)(*new_arg) return delay_run(ttl, suffix_key_func)(func)(*new_arg)
return wrapper return wrapper
return inner return inner
def delay_run(ttl=5): @delay_run(ttl=5)
"""
延迟执行函数, ttl 秒内, 只执行最后一次
:param ttl:
:return:
"""
def inner(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
executor.submit(_run_func_if_is_last, ttl, func, *args, **kwargs)
return wrapper
return inner
@delay_run(ttl=10)
def test_delay_run(username, year=2000): def test_delay_run(username, year=2000):
print("Hello, %s, now is %s" % (username, year)) print("Hello, %s, now is %s" % (username, year))
@merge_delay_run(ttl=10) @merge_delay_run(ttl=5, key=lambda *users: users[0][0])
def test_merge_delay_run(*users): def test_merge_delay_run(*users):
name = ','.join(users) name = ','.join(users)
time.sleep(2)
print("Hello, %s, now is %s" % (name, time.time())) print("Hello, %s, now is %s" % (name, time.time()))
@merge_delay_run(ttl=5, key=lambda *users: users[0][0])
def test_merge_delay_run(*users):
name = ','.join(users)
time.sleep(2)
print("Hello, %s, now is %s" % (name, time.time()))
def do_test():
s = time.time()
print("start : %s" % time.time())
for i in range(100):
# test_delay_run('test', year=i)
test_merge_delay_run('test %s' % i)
test_merge_delay_run('best %s' % i)
end = time.time()
using = end - s
print("end : %s, using: %s" % (end, using))

View File

@ -1,6 +1,6 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from assets.signal_handlers.node_assets_mapping import expire_node_assets_mapping_for_memory from assets.signal_handlers.node_assets_mapping import expire_node_assets_mapping
from orgs.caches import OrgResourceStatisticsCache from orgs.caches import OrgResourceStatisticsCache
from orgs.models import Organization from orgs.models import Organization
@ -10,7 +10,7 @@ def expire_node_assets_mapping():
org_ids = [*org_ids, '00000000-0000-0000-0000-000000000000'] org_ids = [*org_ids, '00000000-0000-0000-0000-000000000000']
for org_id in org_ids: for org_id in org_ids:
expire_node_assets_mapping_for_memory(org_id) expire_node_assets_mapping(org_id)
def expire_org_resource_statistics_cache(): def expire_org_resource_statistics_cache():

View File

@ -60,16 +60,18 @@ def on_request_finished_logging_db_query(sender, **kwargs):
method = current_request.method method = current_request.method
path = current_request.get_full_path() path = current_request.get_full_path()
# print(">>> [{}] {}".format(method, path)) print(">>> [{}] {}".format(method, path))
# for table_name, queries in table_queries.items(): for table_name, queries in table_queries.items():
# if table_name.startswith('rbac_') or table_name.startswith('auth_permission'): if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
# continue continue
# print("- Table: {}".format(table_name)) if len(queries) < 3:
# for i, query in enumerate(queries, 1): continue
# sql = query['sql'] print("- Table: {}".format(table_name))
# if not sql or not sql.startswith('SELECT'): for i, query in enumerate(queries, 1):
# continue sql = query['sql']
# print('\t{}. {}'.format(i, sql)) if not sql or not sql.startswith('SELECT'):
continue
print('\t{}. {}'.format(i, sql))
logger.debug(">>> [{}] {}".format(method, path)) logger.debug(">>> [{}] {}".format(method, path))
for name, counter in counters: for name, counter in counters:

View File

@ -1,19 +1,19 @@
from django.db.models.signals import post_save, pre_delete, pre_save, post_delete from django.db.models.signals import post_save, pre_delete, pre_save, post_delete
from django.dispatch import receiver from django.dispatch import receiver
from orgs.models import Organization
from assets.models import Node
from accounts.models import Account from accounts.models import Account
from assets.models import Asset, Domain
from assets.models import Node
from common.decorators import merge_delay_run
from common.utils import get_logger
from orgs.caches import OrgResourceStatisticsCache
from orgs.models import Organization
from orgs.utils import current_org
from perms.models import AssetPermission from perms.models import AssetPermission
from audits.models import UserLoginLog from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding
from terminal.models import Session
from users.models import UserGroup, User from users.models import UserGroup, User
from users.signals import pre_user_leave_org from users.signals import pre_user_leave_org
from terminal.models import Session
from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding
from assets.models import Asset, Domain
from orgs.caches import OrgResourceStatisticsCache
from orgs.utils import current_org
from common.utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@ -62,24 +62,7 @@ def on_user_delete_refresh_cache(sender, instance, **kwargs):
refresh_all_orgs_user_amount_cache(instance) refresh_all_orgs_user_amount_cache(instance)
# @receiver(m2m_changed, sender=OrganizationMember) model_cache_field_mapper = {
# def on_org_user_changed_refresh_cache(sender, action, instance, reverse, pk_set, **kwargs):
# if not action.startswith(POST_PREFIX):
# return
#
# if reverse:
# orgs = Organization.objects.filter(id__in=pk_set)
# else:
# orgs = [instance]
#
# for org in orgs:
# org_cache = OrgResourceStatisticsCache(org)
# org_cache.expire('users_amount')
# OrgResourceStatisticsCache(Organization.root()).expire('users_amount')
class OrgResourceStatisticsRefreshUtil:
model_cache_field_mapper = {
Node: ['nodes_amount'], Node: ['nodes_amount'],
Domain: ['domains_amount'], Domain: ['domains_amount'],
UserGroup: ['groups_amount'], UserGroup: ['groups_amount'],
@ -87,17 +70,24 @@ class OrgResourceStatisticsRefreshUtil:
RoleBinding: ['users_amount', 'new_users_amount_this_week'], RoleBinding: ['users_amount', 'new_users_amount_this_week'],
Asset: ['assets_amount', 'new_assets_amount_this_week'], Asset: ['assets_amount', 'new_assets_amount_this_week'],
AssetPermission: ['asset_perms_amount'], AssetPermission: ['asset_perms_amount'],
}
}
class OrgResourceStatisticsRefreshUtil:
@staticmethod
@merge_delay_run(ttl=5)
def refresh_org_fields(*org_fields):
for org, cache_field_name in org_fields:
OrgResourceStatisticsCache(org).expire(*cache_field_name)
OrgResourceStatisticsCache(Organization.root()).expire(*cache_field_name)
@classmethod @classmethod
def refresh_if_need(cls, instance): def refresh_if_need(cls, instance):
cache_field_name = cls.model_cache_field_mapper.get(type(instance)) cache_field_name = model_cache_field_mapper.get(type(instance))
if not cache_field_name: if not cache_field_name:
return return
OrgResourceStatisticsCache(Organization.root()).expire(*cache_field_name) org = getattr(instance, 'org', None)
if getattr(instance, 'org', None): cls.refresh_org_fields((org, cache_field_name))
OrgResourceStatisticsCache(instance.org).expire(*cache_field_name)
@receiver(post_save) @receiver(post_save)

View File

@ -1,13 +1,10 @@
from rest_framework import generics
from rest_framework.permissions import AllowAny, IsAuthenticated
from django.conf import settings from django.conf import settings
from rest_framework import generics
from rest_framework.permissions import AllowAny
from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info
from common.utils import get_logger, lazyproperty, get_object_or_none
from authentication.models import ConnectionToken
from orgs.utils import tmp_to_root_org
from common.permissions import IsValidUserOrConnectionToken from common.permissions import IsValidUserOrConnectionToken
from common.utils import get_logger, lazyproperty
from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info
from .. import serializers from .. import serializers
from ..utils import get_interface_setting_or_default from ..utils import get_interface_setting_or_default
@ -58,6 +55,3 @@ class PublicSettingApi(OpenPublicSettingApi):
# 提前把异常爆出来 # 提前把异常爆出来
values[name] = getattr(settings, name) values[name] = getattr(settings, name)
return values return values

View File

@ -1,4 +1,4 @@
from django.db.models.signals import post_save, post_delete from django.db.models.signals import post_delete, post_save
from django.dispatch import receiver from django.dispatch import receiver
from assets.models import Asset from assets.models import Asset

View File

@ -50,7 +50,7 @@ def main():
if resource == 'all': if resource == 'all':
generator_cls = resource_generator_mapper.values() generator_cls = resource_generator_mapper.values()
else: else:
generator_cls.push(resource_generator_mapper[resource]) generator_cls.append(resource_generator_mapper[resource])
for _cls in generator_cls: for _cls in generator_cls:
generator = _cls(org_id=org_id, batch_size=batch_size) generator = _cls(org_id=org_id, batch_size=batch_size)

View File

@ -1,7 +1,8 @@
#!/usr/bin/python #!/usr/bin/python
from random import seed import time
from itertools import islice from itertools import islice
from random import seed
from orgs.models import Organization from orgs.models import Organization
@ -18,7 +19,6 @@ class FakeDataGenerator:
o = Organization.get_instance(org_id, default=Organization.default()) o = Organization.get_instance(org_id, default=Organization.default())
if o: if o:
o.change_to() o.change_to()
print('Current org is: {}'.format(o))
return o return o
def do_generate(self, batch, batch_size): def do_generate(self, batch, batch_size):
@ -38,8 +38,11 @@ class FakeDataGenerator:
batch = list(islice(counter, self.batch_size)) batch = list(islice(counter, self.batch_size))
if not batch: if not batch:
break break
start = time.time()
self.do_generate(batch, self.batch_size) self.do_generate(batch, self.batch_size)
end = time.time()
using = end - start
from_size = created from_size = created
created += len(batch) created += len(batch)
print('Generate %s: %s-%s' % (self.resource, from_size, created)) print('Generate %s: %s-%s [{}s]' % (self.resource, from_size, created, using))
self.after_generate() self.after_generate()