Merge pull request #9494 from jumpserver/pr@dev@perf_api_bulk_add

perf: optimize concurrent processing
pull/9507/head
老广 2023-02-10 15:55:12 +08:00 committed by GitHub
commit eebd6c30de
26 changed files with 357 additions and 420 deletions

View File

@ -1,13 +1,13 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from assets.models import Asset
from accounts.const import SecretType, Source
from accounts.models import Account, AccountTemplate
from accounts.tasks import push_accounts_to_assets
from assets.const import Category, AllTypes
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from assets.models import Asset
from common.serializers import SecretReadableMixin, BulkModelSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from .base import BaseAccountSerializer

View File

@ -1,4 +1,5 @@
from django.db.models.signals import pre_save, post_save
from django.db.models.signals import post_save
from django.db.models.signals import pre_save
from django.dispatch import receiver
from assets.models import Asset

View File

@ -4,9 +4,9 @@ from assets.tasks.common import generate_automation_execution_data
from common.const.choices import Trigger
def automation_execute_start(task_name, tp, child_snapshot=None):
def automation_execute_start(task_name, tp, task_snapshot=None):
from accounts.models import AutomationExecution
data = generate_automation_execution_data(task_name, tp, child_snapshot)
data = generate_automation_execution_data(task_name, tp, task_snapshot)
while True:
try:

View File

@ -1,13 +1,13 @@
# ~*~ coding: utf-8 ~*~
from celery import shared_task
from django.utils.translation import gettext_noop
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext_noop
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
from assets.models import Node
from common.utils import get_logger
from orgs.utils import org_aware_func
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
__all__ = ['gather_asset_accounts']
logger = get_logger(__name__)
@ -18,11 +18,11 @@ def gather_asset_accounts_util(nodes, task_name):
from accounts.models import GatherAccountsAutomation
task_name = GatherAccountsAutomation.generate_unique_name(task_name)
child_snapshot = {
task_snapshot = {
'nodes': [str(node.id) for node in nodes],
}
tp = AutomationTypes.verify_account
automation_execute_start(task_name, tp, child_snapshot)
automation_execute_start(task_name, tp, task_snapshot)
@shared_task(queue="ansible", verbose_name=_('Gather asset accounts'))

View File

@ -1,10 +1,10 @@
from celery import shared_task
from django.utils.translation import gettext_noop, ugettext_lazy as _
from common.utils import get_logger
from orgs.utils import org_aware_func
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
from common.utils import get_logger
from orgs.utils import org_aware_func
logger = get_logger(__file__)
__all__ = [
@ -13,14 +13,14 @@ __all__ = [
def push_util(account, assets, task_name):
child_snapshot = {
task_snapshot = {
'secret': account.secret,
'secret_type': account.secret_type,
'accounts': [account.username],
'assets': [str(asset.id) for asset in assets],
}
tp = AutomationTypes.push_account
automation_execute_start(task_name, tp, child_snapshot)
automation_execute_start(task_name, tp, task_snapshot)
@org_aware_func("assets")

View File

@ -2,10 +2,10 @@ from celery import shared_task
from django.utils.translation import gettext_noop
from django.utils.translation import ugettext as _
from common.utils import get_logger
from assets.const import GATEWAY_NAME
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
from assets.const import GATEWAY_NAME
from common.utils import get_logger
from orgs.utils import org_aware_func
logger = get_logger(__name__)
@ -18,11 +18,11 @@ def verify_connectivity_util(assets, tp, accounts, task_name):
if not assets or not accounts:
return
account_usernames = list(accounts.values_list('username', flat=True))
child_snapshot = {
task_snapshot = {
'accounts': account_usernames,
'assets': [str(asset.id) for asset in assets],
}
automation_execute_start(task_name, tp, child_snapshot)
automation_execute_start(task_name, tp, task_snapshot)
@org_aware_func("assets")

View File

@ -256,8 +256,6 @@ class FamilyMixin:
class NodeAllAssetsMappingMixin:
# Use a new plan
# { org_id: { node_key: [ asset1_id, asset2_id ] } }
orgid_nodekey_assetsid_mapping = defaultdict(dict)
locks_for_get_mapping_from_cache = defaultdict(threading.Lock)
@ -273,20 +271,7 @@ class NodeAllAssetsMappingMixin:
if _mapping:
return _mapping
logger.debug(f'Get node asset mapping from memory failed, acquire thread lock: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
with cls.get_lock(org_id):
logger.debug(f'Acquired thread lock ok. check if mapping is in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
_mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id)
if _mapping:
logger.debug(f'Mapping is already in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return _mapping
_mapping = cls.get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(org_id)
cls.set_node_all_asset_ids_mapping_to_memory(org_id, mapping=_mapping)
return _mapping
@ -302,18 +287,18 @@ class NodeAllAssetsMappingMixin:
cls.orgid_nodekey_assetsid_mapping[org_id] = mapping
@classmethod
def expire_node_all_asset_ids_mapping_from_memory(cls, org_id):
def expire_node_all_asset_ids_memory_mapping(cls, org_id):
org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
@classmethod
def expire_all_orgs_node_all_asset_ids_mapping_from_memory(cls):
def expire_all_orgs_node_all_asset_ids_memory_mapping(cls):
orgs = Organization.objects.all()
org_ids = [str(org.id) for org in orgs]
org_ids.append(Organization.ROOT_ID)
for id in org_ids:
cls.expire_node_all_asset_ids_mapping_from_memory(id)
for i in org_ids:
cls.expire_node_all_asset_ids_memory_mapping(i)
# get order: from memory -> (from cache -> to generate)
@classmethod
@ -332,25 +317,18 @@ class NodeAllAssetsMappingMixin:
return _mapping
_mapping = cls.generate_node_all_asset_ids_mapping(org_id)
cls.set_node_all_asset_ids_mapping_to_cache(org_id=org_id, mapping=_mapping)
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, _mapping, timeout=None)
return _mapping
@classmethod
def get_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
mapping = cache.get(cache_key)
logger.info(f'Get node asset mapping from cache {bool(mapping)}: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return mapping
@classmethod
def set_node_all_asset_ids_mapping_to_cache(cls, org_id, mapping):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, mapping, timeout=None)
@classmethod
def expire_node_all_asset_ids_mapping_from_cache(cls, org_id):
def expire_node_all_asset_ids_cache_mapping(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.delete(cache_key)
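The mixin above layers a per-process dict over the shared Django cache: reads try memory first, then the cache, and only regenerate on a double miss. A minimal sketch of that three-tier read path, with illustrative names (the real mixin adds locking and logging around it):

from collections import defaultdict
from django.core.cache import cache

_memory_mapping = defaultdict(dict)  # {org_id: {node_key: [asset_ids]}}

def get_mapping(org_id, generate):
    # Tier 1: per-process memory
    mapping = _memory_mapping.get(org_id)
    if mapping:
        return mapping
    # Tier 2: shared cache (key format is hypothetical)
    cache_key = 'NODE_ALL_ASSET_IDS_MAPPING_%s' % org_id
    mapping = cache.get(cache_key)
    if not mapping:
        # Tier 3: regenerate from the database and write back to the cache
        mapping = generate(org_id)
        cache.set(cache_key, mapping, timeout=None)
    _memory_mapping[org_id] = mapping
    return mapping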
@ -411,6 +389,14 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key)
return Asset.objects.filter(q).distinct()
def get_assets_amount(self):
q = Q(node__key__startswith=f'{self.key}:') | Q(node__key=self.key)
return self.assets.through.objects.filter(q).count()
def get_assets_account_by_children(self):
children = self.get_all_children().values_list('id', flat=True)
return self.assets.through.objects.filter(node_id__in=children).count()
@classmethod
def get_node_all_assets_by_key_v2(cls, key):
# The original implementation was:

View File

@ -130,7 +130,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
]
read_only_fields = [
'category', 'type', 'connectivity',
'date_verified', 'created_by', 'date_created'
'date_verified', 'created_by', 'date_created',
]
fields = fields_small + fields_fk + fields_m2m + read_only_fields
extra_kwargs = {
@ -228,6 +228,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
node_id = request.query_params.get('node_id')
if not node_id:
return []
nodes = Node.objects.filter(id=node_id)
return nodes
def is_valid(self, raise_exception=False):
self._set_protocols_default()

View File

@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
#
from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete, pre_save
m2m_changed, pre_delete, post_delete, pre_save, post_save
)
from django.dispatch import receiver
from django.utils.translation import gettext_noop
from assets.models import Asset, Node, Cloud, Device, Host, Web, Database
from assets.tasks import test_assets_connectivity_task
from common.const.signals import POST_ADD, POST_REMOVE, PRE_REMOVE
from common.decorators import on_transaction_commit, merge_delay_run
from assets.models import Asset, Node, Host, Database, Device, Web, Cloud
from assets.tasks import test_assets_connectivity_task, gather_assets_facts_task
from common.const.signals import POST_REMOVE, PRE_REMOVE
from common.decorators import on_transaction_commit, merge_delay_run, key_by_org
from common.utils import get_logger
logger = get_logger(__file__)
@ -20,15 +20,33 @@ def on_node_pre_save(sender, instance: Node, **kwargs):
instance.parent_key = instance.compute_parent_key()
@merge_delay_run(ttl=10)
@merge_delay_run(ttl=5, key=key_by_org)
def test_assets_connectivity_handler(*assets):
task_name = gettext_noop("Test assets connectivity ")
test_assets_connectivity_task.delay(assets, task_name)
@merge_delay_run(ttl=10)
@merge_delay_run(ttl=5, key=key_by_org)
def gather_assets_facts_handler(*assets):
pass
if not assets:
logger.info("No assets to update hardware info")
return
name = gettext_noop("Gather asset hardware info")
gather_assets_facts_task.delay(assets=assets, task_name=name)
@merge_delay_run(ttl=5, key=key_by_org)
def ensure_asset_has_node(*assets):
asset_ids = [asset.id for asset in assets]
has_ids = Asset.nodes.through.objects \
.filter(asset_id__in=asset_ids) \
.values_list('asset_id', flat=True)
need_ids = set(asset_ids) - set(has_ids)
if not need_ids:
return
org_root = Node.org_root()
org_root.assets.add(*need_ids)
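All three handlers above are debounced with merge_delay_run(ttl=5, key=key_by_org), so repeated calls inside the ttl window collapse into one call per organization, with their asset arguments merged. A hedged usage sketch of that grouping (demo function only; key_by_org keys on the first argument's org_id):

from common.decorators import merge_delay_run, key_by_org

@merge_delay_run(ttl=5, key=key_by_org)
def demo_assets_handler(*assets):
    # invoked once per org within the ttl window, with all merged assets
    print('handling %s assets' % len(assets))

# assuming asset_a and asset_b belong to the same org, calling
# demo_assets_handler(asset_a) then demo_assets_handler(asset_b)
# delivers a single demo_assets_handler(asset_a, asset_b)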
@receiver(post_save, sender=Asset)
@ -42,38 +60,16 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
return
logger.info("Asset create signal recv: {}".format(instance))
ensure_asset_has_node(instance)
# Gather asset hardware info
test_assets_connectivity_handler([instance])
gather_assets_facts_handler([instance])
# Ensure the asset belongs to at least one node
has_node = instance.nodes.all().exists()
if not has_node:
instance.nodes.add(Node.org_root())
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
"""
This operation makes 4 database queries in total.
Fired when an asset's nodes change, or when a node's asset relations change;
assets newly added under a node are added to the system users bound to that node.
"""
if action != POST_ADD:
return
logger.debug("Assets node add signal recv: {}".format(action))
if reverse:
nodes = [instance.key]
asset_ids = pk_set
else:
nodes = Node.objects.filter(pk__in=pk_set).values_list('key', flat=True)
asset_ids = [instance.id]
# When node assets change, associate the assets with the system users bound to the node and its ancestors; only additions are handled
nodes_ancestors_keys = set()
for node in nodes:
nodes_ancestors_keys.update(Node.get_node_ancestor_keys(node, with_self=True))
auto_info = instance.auto_info
if auto_info.get('ping_enabled'):
logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
test_assets_connectivity_handler(instance)
if auto_info.get('gather_facts_enabled'):
logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
gather_assets_facts_handler(instance)
RELATED_NODE_IDS = '_related_node_ids'
@ -82,19 +78,19 @@ RELATED_NODE_IDS = '_related_node_ids'
@receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, using, **kwargs):
logger.debug("Asset pre delete signal recv: {}".format(instance))
node_ids = set(Node.objects.filter(
assets=instance
).distinct().values_list('id', flat=True))
node_ids = Node.objects.filter(assets=instance) \
.distinct().values_list('id', flat=True)
setattr(instance, RELATED_NODE_IDS, node_ids)
m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False,
model=Node, pk_set=node_ids, using=using, action=PRE_REMOVE
sender=Asset.nodes.through, instance=instance,
reverse=False, model=Node, pk_set=node_ids,
using=using, action=PRE_REMOVE
)
@receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, using, **kwargs):
logger.debug("Asset delete signal recv: {}".format(instance))
logger.debug("Asset post delete signal recv: {}".format(instance))
node_ids = getattr(instance, RELATED_NODE_IDS, None)
if node_ids:
m2m_changed.send(

View File

@ -1,22 +1,21 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from django.db.models import Q, F
from django.db.models.signals import m2m_changed
from django.dispatch import receiver
from django.db.models.signals import (
m2m_changed
)
from orgs.utils import ensure_in_real_or_default_org, tmp_to_org
from assets.models import Asset, Node
from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR
from common.decorators import on_transaction_commit, merge_delay_run
from common.utils import get_logger
from assets.models import Asset, Node, compute_parent_key
from assets.locks import NodeTreeUpdateLock
from orgs.utils import tmp_to_org
from ..tasks import check_node_assets_amount_task
logger = get_logger(__file__)
@on_transaction_commit
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# `pre_clear` is not allowed, because that signal carries no `pk_set`
@ -25,136 +24,29 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
logger.debug('Recv asset nodes change signal, recompute node assets amount')
mapper = {PRE_ADD: add, POST_REMOVE: sub}
if action not in mapper:
return
operator = mapper[action]
with tmp_to_org(instance.org):
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator)
node_ids = [instance.id]
else:
asset_pk = instance.id
# Nodes directly associated with the asset
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator)
node_ids = pk_set
update_nodes_assets_amount(*node_ids)
class NodeAssetsAmountUtils:
@merge_delay_run(ttl=5)
def update_nodes_assets_amount(*node_ids):
nodes = list(Node.objects.filter(id__in=node_ids))
logger.info('Update nodes assets amount: {} nodes'.format(len(node_ids)))
@classmethod
def _remove_ancestor_keys(cls, ancestor_key, tree_set):
# `ancestor_key` must not be empty here, to avoid an infinite loop caused by bad data
# Membership in the set tells whether the key has already been processed
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
if len(node_ids) > 100:
check_node_assets_amount_task.delay()
return
@classmethod
def _is_asset_exists_in_node(cls, asset_pk, node_key):
exists = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).filter(id=asset_pk).exists()
return exists
for node in nodes:
node.assets_amount = node.get_assets_amount()
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_nodes_asset_amount(cls, node_keys, asset_pk, operator):
"""
Update counts when one asset's relations with multiple nodes change
:param node_keys: set of node keys
:param asset_pk: asset id
:param operator: operation (add or sub)
"""
# Ancestor nodes of all related nodes, together forming a partial tree
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# A related node may itself be an ancestor of another related node; if so, drop it from the related set
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# Walk the related nodes, handling each one and its ancestors
# Check whether this node contains the asset being processed
exists = cls._is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# If the asset is in this node, neither it nor its ancestors need processing
cls._remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# Not present, so this node itself must be updated
to_update_keys.append(key)
# `parent_key` must not be empty here, to avoid an infinite loop caused by bad data
# Membership in the set tells whether the key has already been processed
while parent_key and parent_key in ancestor_keys:
exists = cls._is_asset_exists_in_node(asset_pk, parent_key)
if exists:
cls._remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add):
"""
Update counts when one node's relations with multiple assets change
:param node: node instance
:param asset_pk_set: set of asset ids; the function does not modify it
:param operator: operation (add or sub)
* -> Node
# -> Asset
      * [3]
     / \
    *   * [2]
   / \
  *   * [1]
 /   / \
* [a] # # [b]
"""
# Get the keys of node [1]'s ancestors including itself, i.e. the keys of nodes [1, 2, 3]
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# Iterate over the ancestor keys in order [1] -> [2] -> [3]
# Check whether this node and its descendants already contain the assets being
# operated on, and drop those from the working set: they are duplicates, and
# neither adding nor removing them changes the node's asset count
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# The working set is empty, so all assets were duplicates and no counts change;
# and since this node contains them, its ancestors certainly do too,
# so none of the ancestors need processing either
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
Node.objects.bulk_update(nodes, ['assets_amount'])
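The incremental ancestor bookkeeping deleted above is replaced by a plain recount per affected node, with a full async recheck delegated to the celery task once a batch grows past 100 nodes. The recount itself is a single aggregate over the Asset-Node through table; a sketch mirroring get_assets_amount from the node model diff earlier:

from django.db.models import Q

def recount_assets(node):
    # counts assets attached to this node or any of its descendants
    q = Q(node__key__startswith=f'{node.key}:') | Q(node__key=node.key)
    return node.assets.through.objects.filter(q).count()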

View File

@ -2,42 +2,35 @@
#
from django.db.models.signals import (
m2m_changed, post_save, post_delete
post_save, post_delete, m2m_changed
)
from django.dispatch import receiver
from django.utils.functional import LazyObject
from django.utils.functional import lazy
from assets.models import Asset, Node
from assets.models import Node, Asset
from common.decorators import merge_delay_run
from common.signals import django_ready
from common.utils import get_logger
from common.utils.connection import RedisPubSub
from orgs.models import Organization
logger = get_logger(__file__)
logger = get_logger(__name__)
# clear node assets mapping for memory
# ------------------------------------
node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)()
class NodeAssetsMappingForMemoryPubSub(LazyObject):
def _setup(self):
self._wrapped = RedisPubSub('fm.node_all_asset_ids_memory_mapping')
node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub()
def expire_node_assets_mapping_for_memory(org_id):
@merge_delay_run(ttl=5)
def expire_node_assets_mapping(*org_ids):
# All processes clear their own in-memory copy
org_id = str(org_id)
root_org_id = Organization.ROOT_ID
# The current process clears the cache entry
Node.expire_node_all_asset_ids_mapping_from_cache(org_id)
Node.expire_node_all_asset_ids_mapping_from_cache(root_org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
Node.expire_node_all_asset_ids_cache_mapping(root_org_id)
for org_id in set(org_ids):
org_id = str(org_id)
# The current process clears the cache entry
Node.expire_node_all_asset_ids_cache_mapping(org_id)
node_assets_mapping_pub_sub.publish(org_id)
@receiver(post_save, sender=Node)
@ -50,17 +43,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
need_expire = False
if need_expire:
expire_node_assets_mapping_for_memory(instance.org_id)
expire_node_assets_mapping(instance.org_id)
@receiver(post_delete, sender=Node)
def on_node_post_delete(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
expire_node_assets_mapping(instance.org_id)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
logger.debug("Recv asset nodes changed signal, expire memery node asset mapping")
expire_node_assets_mapping(instance.org_id)
@receiver(django_ready)
@ -69,7 +63,7 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
def handle_node_relation_change(org_id):
root_org_id = Organization.ROOT_ID
Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
Node.expire_node_all_asset_ids_mapping_from_memory(root_org_id)
Node.expire_node_all_asset_ids_memory_mapping(org_id)
Node.expire_node_all_asset_ids_memory_mapping(root_org_id)
node_assets_mapping_for_memory_pub_sub.subscribe(handle_node_relation_change)
node_assets_mapping_pub_sub.subscribe(handle_node_relation_change)
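The channel fan-out is what lets each worker process drop only its own in-memory copy, while the shared cache entry is deleted once by the publisher. A bare-bones sketch of the same pattern using redis-py directly (channel name taken from the diff, everything else illustrative):

import redis

client = redis.Redis()

def publish_expire(org_id):
    client.publish('fm.node_asset_mapping', str(org_id))

def run_subscriber(handle_org_id):
    pubsub = client.pubsub()
    pubsub.subscribe('fm.node_asset_mapping')
    for message in pubsub.listen():
        if message['type'] == 'message':
            # each process expires its own memory mapping for this org
            handle_org_id(message['data'].decode())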

View File

@ -8,8 +8,8 @@ from common.const.choices import Trigger
from orgs.utils import current_org
def generate_automation_execution_data(task_name, tp, child_snapshot=None):
child_snapshot = child_snapshot or {}
def generate_automation_execution_data(task_name, tp, task_snapshot=None):
task_snapshot = task_snapshot or {}
from assets.models import BaseAutomation
try:
eid = current_task.request.id
@ -25,13 +25,13 @@ def generate_automation_execution_data(task_name, tp, child_snapshot=None):
automation_instance = BaseAutomation()
snapshot = automation_instance.to_attr_json()
snapshot.update(data)
snapshot.update(child_snapshot)
snapshot.update(task_snapshot)
return {'id': eid, 'snapshot': snapshot}
def quickstart_automation(task_name, tp, child_snapshot=None):
def quickstart_automation(task_name, tp, task_snapshot=None):
from assets.models import AutomationExecution
data = generate_automation_execution_data(task_name, tp, child_snapshot)
data = generate_automation_execution_data(task_name, tp, task_snapshot)
while True:
try:

View File

@ -1,65 +1,55 @@
# -*- coding: utf-8 -*-
#
from itertools import chain
from celery import shared_task
from django.utils.translation import gettext_noop, gettext_lazy as _
from assets.const import AutomationTypes
from common.utils import get_logger
from orgs.utils import org_aware_func
from orgs.utils import tmp_to_org
from .common import quickstart_automation
logger = get_logger(__file__)
__all__ = [
'update_assets_fact_util',
'gather_assets_facts_task',
'update_node_assets_hardware_info_manual',
'update_assets_hardware_info_manual',
]
def update_fact_util(assets=None, nodes=None, task_name=None):
@shared_task(queue="ansible", verbose_name=_('Gather assets facts'))
def gather_assets_facts_task(assets=None, nodes=None, task_name=None):
from assets.models import GatherFactsAutomation
if task_name is None:
task_name = gettext_noop("Update some assets hardware info. ")
task_name = gettext_noop("Gather assets facts")
task_name = GatherFactsAutomation.generate_unique_name(task_name)
nodes = nodes or []
assets = assets or []
child_snapshot = {
resources = list(chain(assets, nodes))
if not resources:
raise ValueError("nodes or assets must be given")
org_id = resources[0].org_id
task_snapshot = {
'assets': [str(asset.id) for asset in assets],
'nodes': [str(node.id) for node in nodes],
}
tp = AutomationTypes.gather_facts
quickstart_automation(task_name, tp, child_snapshot)
with tmp_to_org(org_id):
quickstart_automation(task_name, tp, task_snapshot)
@org_aware_func('assets')
def update_assets_fact_util(assets=None, task_name=None):
if assets is None:
logger.info("No assets to update hardware info")
return
update_fact_util(assets=assets, task_name=task_name)
@org_aware_func('nodes')
def update_nodes_fact_util(nodes=None, task_name=None):
if nodes is None:
logger.info("No nodes to update hardware info")
return
update_fact_util(nodes=nodes, task_name=task_name)
@shared_task(queue="ansible", verbose_name=_('Manually update the hardware information of assets'))
def update_assets_hardware_info_manual(asset_ids):
from assets.models import Asset
assets = Asset.objects.filter(id__in=asset_ids)
task_name = gettext_noop("Update assets hardware info: ")
update_assets_fact_util(assets=assets, task_name=task_name)
gather_assets_facts_task.delay(assets=assets, task_name=task_name)
@shared_task(queue="ansible", verbose_name=_('Manually update the hardware information of assets under a node'))
def update_node_assets_hardware_info_manual(node_id):
from assets.models import Node
node = Node.objects.get(id=node_id)
task_name = gettext_noop("Update node asset hardware information: ")
update_nodes_fact_util(nodes=[node], task_name=task_name)
gather_assets_facts_task.delay(nodes=[node], task_name=task_name)

View File

@ -24,8 +24,8 @@ def test_assets_connectivity_task(assets, task_name=None):
task_name = gettext_noop("Test assets connectivity ")
task_name = PingAutomation.generate_unique_name(task_name)
child_snapshot = {'assets': [str(asset.id) for asset in assets]}
quickstart_automation(task_name, AutomationTypes.ping, child_snapshot)
task_snapshot = {'assets': [str(asset.id) for asset in assets]}
quickstart_automation(task_name, AutomationTypes.ping, task_snapshot)
def test_assets_connectivity_manual(asset_ids):

View File

@ -1,12 +1,12 @@
# ~*~ coding: utf-8 ~*~
#
from collections import defaultdict
from common.db.models import output_as_string
from common.struct import Stack
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit
from common.utils.http import is_true
from common.struct import Stack
from common.db.models import output_as_string
from orgs.utils import ensure_in_real_or_default_org, current_org
from ..locks import NodeTreeUpdateLock
from ..models import Node, Asset
@ -25,11 +25,11 @@ def check_node_assets_amount():
for node in nodes:
nodeid_nodekey_mapper[node.id] = node.key
for nodeid, assetid in nodeid_assetid_pairs:
if nodeid not in nodeid_nodekey_mapper:
for node_id, asset_id in nodeid_assetid_pairs:
if node_id not in nodeid_nodekey_mapper:
continue
nodekey = nodeid_nodekey_mapper[nodeid]
nodekey_assetids_mapper[nodekey].add(assetid)
node_key = nodeid_nodekey_mapper[node_id]
nodekey_assetids_mapper[node_key].add(asset_id)
util = NodeAssetsUtil(nodes, nodekey_assetids_mapper)
util.generate()

View File

@ -1,14 +1,8 @@
# -*- coding: utf-8 -*-
#
from django.utils.translation import ugettext_lazy as _
create_success_msg = _("%(name)s was created successfully")
update_success_msg = _("%(name)s was updated successfully")
FILE_END_GUARD = ">>> Content End <<<"
celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable
# https://docs.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties
LDAP_AD_ACCOUNT_DISABLE = 2
from .choices import *
from .common import *
from .crontab import *
from .http import *
from .signals import *

View File

@ -0,0 +1,11 @@
from django.utils.translation import ugettext_lazy as _
create_success_msg = _("%(name)s was created successfully")
update_success_msg = _("%(name)s was updated successfully")
FILE_END_GUARD = ">>> Content End <<<"
celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable
# https://docs.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties
LDAP_AD_ACCOUNT_DISABLE = 2

View File

@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
#
import asyncio
import functools
import inspect
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
@ -9,6 +11,8 @@ from concurrent.futures import ThreadPoolExecutor
from django.core.cache import cache
from django.db import transaction
from .utils import logger
def on_transaction_commit(func):
"""
@ -34,54 +38,124 @@ class Singleton(object):
return self._instance[self._cls]
def _run_func_if_is_last(ttl, func, *args, **kwargs):
ix = uuid.uuid4().__str__()
key = f'DELAY_RUN_{func.__name__}'
cache.set(key, ix, ttl)
st = (ttl - 2 > 1) and ttl - 2 or 2
time.sleep(st)
got = cache.get(key, None)
def default_suffix_key(*args, **kwargs):
return 'default'
if ix == got:
def key_by_org(*args, **kwargs):
return args[0].org_id
def _run_func_if_is_last(ttl, suffix_key, org, func, *args, **kwargs):
from orgs.utils import set_current_org
try:
set_current_org(org)
uid = uuid.uuid4().__str__()
suffix_key_func = suffix_key if suffix_key else default_suffix_key
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
key = f'DELAY_RUN_{func_name}_{key_suffix}'
cache.set(key, uid, ttl)
st = (ttl - 2 > 1) and ttl - 2 or 2
time.sleep(st)
ret = cache.get(key, None)
if uid == ret:
func(*args, **kwargs)
except Exception as e:
logger.error('delay run error: %s' % e)
class LoopThread(threading.Thread):
def __init__(self, loop, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loop = loop
def run(self) -> None:
asyncio.set_event_loop(self.loop)
self.loop.run_forever()
print('loop stopped')
loop = asyncio.get_event_loop()
loop_thread = LoopThread(loop)
loop_thread.daemon = True
loop_thread.start()
executor = ThreadPoolExecutor(max_workers=5, thread_name_prefix='debouncer')
class Debouncer(object):
def __init__(self, callback, check, delay, *args, **kwargs):
self.callback = callback
self.check = check
self.delay = delay
async def __call__(self, *args, **kwargs):
await asyncio.sleep(self.delay)
ok = await self._check(*args, **kwargs)
if ok:
await loop.run_in_executor(executor, self.callback, *args)
async def _check(self, *args, **kwargs):
if asyncio.iscoroutinefunction(self.check):
return await self.check(*args, **kwargs)
return await loop.run_in_executor(executor, self.check)
def _run_func_with_org(org, func, *args, **kwargs):
from orgs.utils import set_current_org
try:
set_current_org(org)
func(*args, **kwargs)
except Exception as e:
logger.error('delay run error: %s' % e)
executor = ThreadPoolExecutor(10)
def delay_run(ttl=5, key=None):
"""
Delay execution: within ttl seconds, only the last call runs
:param ttl:
:param key: callback that derives a grouping key from the call arguments
:return:
"""
def delay_run(ttl=5):
def inner(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
executor.submit(_run_func_if_is_last, ttl, func, *args, **kwargs)
from orgs.utils import get_current_org
org = get_current_org()
suffix_key_func = key if key else default_suffix_key
uid = uuid.uuid4().__str__()
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
cache_key = f'DELAY_RUN_{func_name}_{key_suffix}'
# Cache the token for twice the ttl, so it cannot expire before the check runs
cache.set(cache_key, uid, ttl * 2)
def _check_func(key_id, key_value):
ret = cache.get(key_id, None)
return key_value == ret
check_func_partial = functools.partial(_check_func, cache_key, uid)
run_func_partial = functools.partial(_run_func_with_org, org, func)
asyncio.run_coroutine_threadsafe(
Debouncer(run_func_partial, check_func_partial, ttl)(*args, **kwargs),
loop=loop
)
return wrapper
return inner
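Both decorators boil down to a last-writer-wins token: every call writes a fresh uuid under a shared cache key, waits out the ttl, and only runs if its uuid is still the latest. A stripped-down synchronous sketch of just that check (the committed version instead does the wait on the asyncio loop via Debouncer, so callers never block):

import time
import uuid
from django.core.cache import cache

def debounce(ttl=5):
    def inner(func):
        def wrapper(*args, **kwargs):
            token = str(uuid.uuid4())
            key = 'DEBOUNCE_%s' % func.__name__
            cache.set(key, token, ttl * 2)
            time.sleep(ttl)
            if cache.get(key) == token:  # no newer call arrived during the wait
                func(*args, **kwargs)
        return wrapper
    return inner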
def _merge_run(ttl, func, *args, **kwargs):
if not args or not isinstance(args[0], (list, tuple)):
raise ValueError('args[0] must be list or tuple')
key = f'DELAY_MERGE_RUN_{func.__name__}'
ix = uuid.uuid4().__str__()
value = cache.get(key, [])
value.extend(args[0])
st = (ttl - 2 > 1) and ttl - 2 or 2
time.sleep(st)
got = cache.get(key, None)
if ix == got:
func(*args, **kwargs)
def merge_delay_run(ttl):
def merge_delay_run(ttl, key=None):
"""
Merge the arguments of delayed calls: within ttl seconds, only the last call runs, receiving all merged args
func must take *args only
:param ttl:
:param key: callback that derives a grouping key from the call arguments
:return:
"""
@ -93,42 +167,50 @@ def merge_delay_run(ttl):
if not str(param).startswith('*'):
raise ValueError('func args must be startswith *: %s' % func.__name__)
suffix_key_func = key if key else default_suffix_key
@functools.wraps(func)
def wrapper(*args):
key = f'DELAY_MERGE_RUN_{func.__name__}'
values = cache.get(key, [])
key_suffix = suffix_key_func(*args)
func_name = f'{func.__module__}_{func.__name__}'
cache_key = f'DELAY_MERGE_RUN_{func_name}_{key_suffix}'
values = cache.get(cache_key, [])
new_arg = [*values, *args]
cache.set(key, new_arg, ttl)
return delay_run(ttl)(func)(*new_arg)
cache.set(cache_key, new_arg, ttl)
return delay_run(ttl, suffix_key_func)(func)(*new_arg)
return wrapper
return inner
def delay_run(ttl=5):
"""
Delay execution: within ttl seconds, only the last call runs
:param ttl:
:return:
"""
def inner(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
executor.submit(_run_func_if_is_last, ttl, func, *args, **kwargs)
return wrapper
return inner
@delay_run(ttl=10)
@delay_run(ttl=5)
def test_delay_run(username, year=2000):
print("Hello, %s, now is %s" % (username, year))
@merge_delay_run(ttl=10)
@merge_delay_run(ttl=5, key=lambda *users: users[0][0])
def test_merge_delay_run(*users):
name = ','.join(users)
time.sleep(2)
print("Hello, %s, now is %s" % (name, time.time()))
def do_test():
s = time.time()
print("start : %s" % time.time())
for i in range(100):
# test_delay_run('test', year=i)
test_merge_delay_run('test %s' % i)
test_merge_delay_run('best %s' % i)
end = time.time()
using = end - s
print("end : %s, using: %s" % (end, using))

View File

@ -1,6 +1,6 @@
from django.core.management.base import BaseCommand
from assets.signal_handlers.node_assets_mapping import expire_node_assets_mapping_for_memory
from assets.signal_handlers.node_assets_mapping import expire_node_assets_mapping
from orgs.caches import OrgResourceStatisticsCache
from orgs.models import Organization
@ -10,7 +10,7 @@ def expire_node_assets_mapping():
org_ids = [*org_ids, '00000000-0000-0000-0000-000000000000']
for org_id in org_ids:
expire_node_assets_mapping_for_memory(org_id)
expire_node_assets_mapping(org_id)
def expire_org_resource_statistics_cache():

View File

@ -60,16 +60,18 @@ def on_request_finished_logging_db_query(sender, **kwargs):
method = current_request.method
path = current_request.get_full_path()
# print(">>> [{}] {}".format(method, path))
# for table_name, queries in table_queries.items():
# if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
# continue
# print("- Table: {}".format(table_name))
# for i, query in enumerate(queries, 1):
# sql = query['sql']
# if not sql or not sql.startswith('SELECT'):
# continue
# print('\t{}. {}'.format(i, sql))
print(">>> [{}] {}".format(method, path))
for table_name, queries in table_queries.items():
if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
continue
if len(queries) < 3:
continue
print("- Table: {}".format(table_name))
for i, query in enumerate(queries, 1):
sql = query['sql']
if not sql or not sql.startswith('SELECT'):
continue
print('\t{}. {}'.format(i, sql))
logger.debug(">>> [{}] {}".format(method, path))
for name, counter in counters:

View File

@ -1,19 +1,19 @@
from django.db.models.signals import post_save, pre_delete, pre_save, post_delete
from django.dispatch import receiver
from orgs.models import Organization
from assets.models import Node
from accounts.models import Account
from assets.models import Asset, Domain
from assets.models import Node
from common.decorators import merge_delay_run
from common.utils import get_logger
from orgs.caches import OrgResourceStatisticsCache
from orgs.models import Organization
from orgs.utils import current_org
from perms.models import AssetPermission
from audits.models import UserLoginLog
from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding
from terminal.models import Session
from users.models import UserGroup, User
from users.signals import pre_user_leave_org
from terminal.models import Session
from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding
from assets.models import Asset, Domain
from orgs.caches import OrgResourceStatisticsCache
from orgs.utils import current_org
from common.utils import get_logger
logger = get_logger(__name__)
@ -62,42 +62,32 @@ def on_user_delete_refresh_cache(sender, instance, **kwargs):
refresh_all_orgs_user_amount_cache(instance)
# @receiver(m2m_changed, sender=OrganizationMember)
# def on_org_user_changed_refresh_cache(sender, action, instance, reverse, pk_set, **kwargs):
# if not action.startswith(POST_PREFIX):
# return
#
# if reverse:
# orgs = Organization.objects.filter(id__in=pk_set)
# else:
# orgs = [instance]
#
# for org in orgs:
# org_cache = OrgResourceStatisticsCache(org)
# org_cache.expire('users_amount')
# OrgResourceStatisticsCache(Organization.root()).expire('users_amount')
model_cache_field_mapper = {
Node: ['nodes_amount'],
Domain: ['domains_amount'],
UserGroup: ['groups_amount'],
Account: ['accounts_amount'],
RoleBinding: ['users_amount', 'new_users_amount_this_week'],
Asset: ['assets_amount', 'new_assets_amount_this_week'],
AssetPermission: ['asset_perms_amount'],
}
class OrgResourceStatisticsRefreshUtil:
model_cache_field_mapper = {
Node: ['nodes_amount'],
Domain: ['domains_amount'],
UserGroup: ['groups_amount'],
Account: ['accounts_amount'],
RoleBinding: ['users_amount', 'new_users_amount_this_week'],
Asset: ['assets_amount', 'new_assets_amount_this_week'],
AssetPermission: ['asset_perms_amount'],
}
@staticmethod
@merge_delay_run(ttl=5)
def refresh_org_fields(*org_fields):
for org, cache_field_name in org_fields:
OrgResourceStatisticsCache(org).expire(*cache_field_name)
OrgResourceStatisticsCache(Organization.root()).expire(*cache_field_name)
@classmethod
def refresh_if_need(cls, instance):
cache_field_name = cls.model_cache_field_mapper.get(type(instance))
cache_field_name = model_cache_field_mapper.get(type(instance))
if not cache_field_name:
return
OrgResourceStatisticsCache(Organization.root()).expire(*cache_field_name)
if getattr(instance, 'org', None):
OrgResourceStatisticsCache(instance.org).expire(*cache_field_name)
org = getattr(instance, 'org', None)
cls.refresh_org_fields((org, cache_field_name))
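Because refresh_org_fields is wrapped in merge_delay_run, every save within the ttl window contributes one (org, fields) tuple and the cache expiry runs once over the merged batch. An illustrative standalone sketch of that shape (demo function, not the real cache util):

from common.decorators import merge_delay_run

@merge_delay_run(ttl=5)
def refresh_fields_demo(*org_fields):
    # receives the tuples from every call made within the ttl window
    for org, fields in org_fields:
        print('expire', org, fields)

refresh_fields_demo(('org-a', ['assets_amount']))
refresh_fields_demo(('org-a', ['nodes_amount']))
# -> one merged invocation receiving both tuples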
@receiver(post_save)

View File

@ -1,13 +1,10 @@
from rest_framework import generics
from rest_framework.permissions import AllowAny, IsAuthenticated
from django.conf import settings
from rest_framework import generics
from rest_framework.permissions import AllowAny
from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info
from common.utils import get_logger, lazyproperty, get_object_or_none
from authentication.models import ConnectionToken
from orgs.utils import tmp_to_root_org
from common.permissions import IsValidUserOrConnectionToken
from common.utils import get_logger, lazyproperty
from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info
from .. import serializers
from ..utils import get_interface_setting_or_default
@ -58,6 +55,3 @@ class PublicSettingApi(OpenPublicSettingApi):
# Surface any exception early
values[name] = getattr(settings, name)
return values

View File

@ -1,4 +1,4 @@
from django.db.models.signals import post_save, post_delete
from django.db.models.signals import post_delete, post_save
from django.dispatch import receiver
from assets.models import Asset

View File

@ -50,7 +50,7 @@ def main():
if resource == 'all':
generator_cls = resource_generator_mapper.values()
else:
generator_cls.push(resource_generator_mapper[resource])
generator_cls.append(resource_generator_mapper[resource])
for _cls in generator_cls:
generator = _cls(org_id=org_id, batch_size=batch_size)

View File

@ -1,7 +1,8 @@
#!/usr/bin/python
from random import seed
import time
from itertools import islice
from random import seed
from orgs.models import Organization
@ -18,7 +19,6 @@ class FakeDataGenerator:
o = Organization.get_instance(org_id, default=Organization.default())
if o:
o.change_to()
print('Current org is: {}'.format(o))
return o
def do_generate(self, batch, batch_size):
@ -38,8 +38,11 @@ class FakeDataGenerator:
batch = list(islice(counter, self.batch_size))
if not batch:
break
start = time.time()
self.do_generate(batch, self.batch_size)
end = time.time()
using = end - start
from_size = created
created += len(batch)
print('Generate %s: %s-%s' % (self.resource, from_size, created))
self.after_generate()
print('Generate %s: %s-%s [%ss]' % (self.resource, from_size, created, using))
self.after_generate()