perf: Optimize concurrent processing

pull/9494/head
ibuler 2023-02-09 20:48:25 +08:00
parent e590518108
commit 37a52c420f
26 changed files with 274 additions and 416 deletions

View File

@ -1,13 +1,13 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from assets.models import Asset
from accounts.const import SecretType, Source
from accounts.models import Account, AccountTemplate
from accounts.tasks import push_accounts_to_assets
from assets.const import Category, AllTypes
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from assets.models import Asset
from common.serializers import SecretReadableMixin, BulkModelSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from .base import BaseAccountSerializer
@ -47,7 +47,7 @@ class AccountSerializerCreateValidateMixin:
def create(self, validated_data):
push_now = validated_data.pop('push_now', None)
instance = super().create(validated_data, push_now)
instance = super().create(validated_data)
self.push_account(instance, push_now)
return instance

View File

@ -1,4 +1,5 @@
from django.db.models.signals import pre_save, post_save
from django.db.models.signals import post_save
from django.db.models.signals import pre_save
from django.dispatch import receiver
from assets.models import Asset

View File

@ -4,9 +4,9 @@ from assets.tasks.common import generate_automation_execution_data
from common.const.choices import Trigger
def automation_execute_start(task_name, tp, child_snapshot=None):
def automation_execute_start(task_name, tp, task_snapshot=None):
from accounts.models import AutomationExecution
data = generate_automation_execution_data(task_name, tp, child_snapshot)
data = generate_automation_execution_data(task_name, tp, task_snapshot)
while True:
try:

View File

@ -1,13 +1,13 @@
# ~*~ coding: utf-8 ~*~
from celery import shared_task
from django.utils.translation import gettext_noop
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext_noop
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
from assets.models import Node
from common.utils import get_logger
from orgs.utils import org_aware_func
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
__all__ = ['gather_asset_accounts']
logger = get_logger(__name__)
@ -18,11 +18,11 @@ def gather_asset_accounts_util(nodes, task_name):
from accounts.models import GatherAccountsAutomation
task_name = GatherAccountsAutomation.generate_unique_name(task_name)
child_snapshot = {
task_snapshot = {
'nodes': [str(node.id) for node in nodes],
}
tp = AutomationTypes.verify_account
automation_execute_start(task_name, tp, child_snapshot)
automation_execute_start(task_name, tp, task_snapshot)
@shared_task(queue="ansible", verbose_name=_('Gather asset accounts'))

View File

@ -1,10 +1,10 @@
from celery import shared_task
from django.utils.translation import gettext_noop, ugettext_lazy as _
from common.utils import get_logger
from orgs.utils import org_aware_func
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
from common.utils import get_logger
from orgs.utils import org_aware_func
logger = get_logger(__file__)
__all__ = [
@ -13,14 +13,14 @@ __all__ = [
def push_util(account, assets, task_name):
child_snapshot = {
task_snapshot = {
'secret': account.secret,
'secret_type': account.secret_type,
'accounts': [account.username],
'assets': [str(asset.id) for asset in assets],
}
tp = AutomationTypes.push_account
automation_execute_start(task_name, tp, child_snapshot)
automation_execute_start(task_name, tp, task_snapshot)
@org_aware_func("assets")

View File

@ -2,10 +2,10 @@ from celery import shared_task
from django.utils.translation import gettext_noop
from django.utils.translation import ugettext as _
from common.utils import get_logger
from assets.const import GATEWAY_NAME
from accounts.const import AutomationTypes
from accounts.tasks.common import automation_execute_start
from assets.const import GATEWAY_NAME
from common.utils import get_logger
from orgs.utils import org_aware_func
logger = get_logger(__name__)
@ -18,11 +18,11 @@ def verify_connectivity_util(assets, tp, accounts, task_name):
if not assets or not accounts:
return
account_usernames = list(accounts.values_list('username', flat=True))
child_snapshot = {
task_snapshot = {
'accounts': account_usernames,
'assets': [str(asset.id) for asset in assets],
}
automation_execute_start(task_name, tp, child_snapshot)
automation_execute_start(task_name, tp, task_snapshot)
@org_aware_func("assets")

View File

@ -256,8 +256,6 @@ class FamilyMixin:
class NodeAllAssetsMappingMixin:
# Use a new plan
# { org_id: { node_key: [ asset1_id, asset2_id ] } }
orgid_nodekey_assetsid_mapping = defaultdict(dict)
locks_for_get_mapping_from_cache = defaultdict(threading.Lock)
@ -273,20 +271,7 @@ class NodeAllAssetsMappingMixin:
if _mapping:
return _mapping
logger.debug(f'Get node asset mapping from memory failed, acquire thread lock: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
with cls.get_lock(org_id):
logger.debug(f'Acquired thread lock ok. check if mapping is in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
_mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id)
if _mapping:
logger.debug(f'Mapping is already in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return _mapping
_mapping = cls.get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(org_id)
cls.set_node_all_asset_ids_mapping_to_memory(org_id, mapping=_mapping)
return _mapping
@ -302,18 +287,18 @@ class NodeAllAssetsMappingMixin:
cls.orgid_nodekey_assetsid_mapping[org_id] = mapping
@classmethod
def expire_node_all_asset_ids_mapping_from_memory(cls, org_id):
def expire_node_all_asset_ids_memory_mapping(cls, org_id):
org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
@classmethod
def expire_all_orgs_node_all_asset_ids_mapping_from_memory(cls):
def expire_all_orgs_node_all_asset_ids_memory_mapping(cls):
orgs = Organization.objects.all()
org_ids = [str(org.id) for org in orgs]
org_ids.append(Organization.ROOT_ID)
for id in org_ids:
cls.expire_node_all_asset_ids_mapping_from_memory(id)
for i in org_ids:
cls.expire_node_all_asset_ids_memory_mapping(i)
# get order: from memory -> (from cache -> to generate)
@classmethod
@ -332,25 +317,18 @@ class NodeAllAssetsMappingMixin:
return _mapping
_mapping = cls.generate_node_all_asset_ids_mapping(org_id)
cls.set_node_all_asset_ids_mapping_to_cache(org_id=org_id, mapping=_mapping)
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, _mapping, timeout=None)
return _mapping
@classmethod
def get_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
mapping = cache.get(cache_key)
logger.info(f'Get node asset mapping from cache {bool(mapping)}: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return mapping
@classmethod
def set_node_all_asset_ids_mapping_to_cache(cls, org_id, mapping):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, mapping, timeout=None)
@classmethod
def expire_node_all_asset_ids_mapping_from_cache(cls, org_id):
def expire_node_all_asset_ids_cache_mapping(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.delete(cache_key)
@ -411,6 +389,14 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key)
return Asset.objects.filter(q).distinct()
def get_assets_amount(self):
q = Q(node__key__startswith=f'{self.key}:') | Q(node__key=self.key)
return self.assets.through.objects.filter(q).count()
def get_assets_account_by_children(self):
children = self.get_all_children().values_list('id', flat=True)
return self.assets.through.objects.filter(node_id__in=children).count()
@classmethod
def get_node_all_assets_by_key_v2(cls, key):
# The original implementation was:

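The hunks above drop the per-org thread lock and keep a three-tier read path: process memory, then the shared Django cache, then a rebuild from the database. A minimal sketch of that order, assuming Django's cache framework; the class name, cache-key format and generate_mapping() are illustrative stand-ins, not the project's exact names:

from django.core.cache import cache

class NodeAssetsMappingSketch:
    # per-process memory tier: { org_id: { node_key: [asset_id, ...] } }
    _memory = {}

    @classmethod
    def get_mapping(cls, org_id):
        org_id = str(org_id)
        mapping = cls._memory.get(org_id)                      # 1. process memory
        if mapping:
            return mapping
        cache_key = 'NODE_ALL_ASSET_IDS_MAPPING_%s' % org_id   # hypothetical key
        mapping = cache.get(cache_key)                         # 2. shared cache
        if not mapping:
            mapping = cls.generate_mapping(org_id)             # 3. rebuild from DB (stand-in)
            cache.set(cache_key, mapping, timeout=None)
        cls._memory[org_id] = mapping                          # backfill the memory tier
        return mapping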
View File

@ -124,7 +124,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
]
read_only_fields = [
'category', 'type', 'connectivity',
'date_verified', 'created_by', 'date_created'
'date_verified', 'created_by', 'date_created',
]
fields = fields_small + fields_fk + fields_m2m + read_only_fields
extra_kwargs = {
@ -222,6 +222,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
node_id = request.query_params.get('node_id')
if not node_id:
return []
nodes = Node.objects.filter(id=node_id)
return nodes
def is_valid(self, raise_exception=False):
self._set_protocols_default()

View File

@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
#
from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete, pre_save
m2m_changed, pre_delete, post_delete, pre_save, post_save
)
from django.dispatch import receiver
from django.utils.translation import gettext_noop
from assets.models import Asset, Node, Cloud, Device, Host, Web, Database
from assets.tasks import test_assets_connectivity_task
from common.const.signals import POST_ADD, POST_REMOVE, PRE_REMOVE
from common.decorators import on_transaction_commit, merge_delay_run
from assets.models import Asset, Node, Host, Database, Device, Web, Cloud
from assets.tasks import test_assets_connectivity_task, gather_assets_facts_task
from common.const.signals import POST_REMOVE, PRE_REMOVE
from common.decorators import on_transaction_commit, merge_delay_run, key_by_org
from common.utils import get_logger
logger = get_logger(__file__)
@ -20,15 +20,33 @@ def on_node_pre_save(sender, instance: Node, **kwargs):
instance.parent_key = instance.compute_parent_key()
@merge_delay_run(ttl=10)
@merge_delay_run(ttl=5, key=key_by_org)
def test_assets_connectivity_handler(*assets):
task_name = gettext_noop("Test assets connectivity ")
test_assets_connectivity_task.delay(assets, task_name)
@merge_delay_run(ttl=10)
@merge_delay_run(ttl=5, key=key_by_org)
def gather_assets_facts_handler(*assets):
pass
if not assets:
logger.info("No assets to update hardware info")
return
name = gettext_noop("Gather asset hardware info")
gather_assets_facts_task.delay(assets=assets, task_name=name)
@merge_delay_run(ttl=5, key=key_by_org)
def ensure_asset_has_node(*assets):
asset_ids = [asset.id for asset in assets]
has_ids = Asset.nodes.through.objects \
.filter(asset_id__in=asset_ids) \
.values_list('asset_id', flat=True)
need_ids = set(asset_ids) - set(has_ids)
if not need_ids:
return
org_root = Node.org_root()
org_root.assets.add(*need_ids)
@receiver(post_save, sender=Asset)
@ -42,38 +60,16 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
return
logger.info("Asset create signal recv: {}".format(instance))
ensure_asset_has_node(instance)
# Gather asset hardware info
test_assets_connectivity_handler([instance])
gather_assets_facts_handler([instance])
# Ensure the asset belongs to at least one node
has_node = instance.nodes.all().exists()
if not has_node:
instance.nodes.add(Node.org_root())
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
"""
This operation hits the database 4 times in total.
Runs when an asset's nodes change or a node's asset relations change:
assets newly added under a node are attached to that node's system users.
"""
if action != POST_ADD:
return
logger.debug("Assets node add signal recv: {}".format(action))
if reverse:
nodes = [instance.key]
asset_ids = pk_set
else:
nodes = Node.objects.filter(pk__in=pk_set).values_list('key', flat=True)
asset_ids = [instance.id]
# When a node's assets change, attach the assets to the system users bound to the node and its ancestor nodes; only additions are handled
nodes_ancestors_keys = set()
for node in nodes:
nodes_ancestors_keys.update(Node.get_node_ancestor_keys(node, with_self=True))
auto_info = instance.auto_info
if auto_info.get('ping_enabled'):
logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
test_assets_connectivity_handler(instance)
if auto_info.get('gather_facts_enabled'):
logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
gather_assets_facts_handler(instance)
RELATED_NODE_IDS = '_related_node_ids'
@ -82,19 +78,19 @@ RELATED_NODE_IDS = '_related_node_ids'
@receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, using, **kwargs):
logger.debug("Asset pre delete signal recv: {}".format(instance))
node_ids = set(Node.objects.filter(
assets=instance
).distinct().values_list('id', flat=True))
node_ids = Node.objects.filter(assets=instance) \
.distinct().values_list('id', flat=True)
setattr(instance, RELATED_NODE_IDS, node_ids)
m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False,
model=Node, pk_set=node_ids, using=using, action=PRE_REMOVE
sender=Asset.nodes.through, instance=instance,
reverse=False, model=Node, pk_set=node_ids,
using=using, action=PRE_REMOVE
)
@receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, using, **kwargs):
logger.debug("Asset delete signal recv: {}".format(instance))
logger.debug("Asset post delete signal recv: {}".format(instance))
node_ids = getattr(instance, RELATED_NODE_IDS, None)
if node_ids:
m2m_changed.send(

View File

@ -1,22 +1,20 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from django.db.models import Q, F
from django.db.models.signals import m2m_changed
from django.dispatch import receiver
from django.db.models.signals import (
m2m_changed
)
from orgs.utils import ensure_in_real_or_default_org, tmp_to_org
from assets.models import Asset, Node
from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR
from common.decorators import on_transaction_commit, merge_delay_run
from common.utils import get_logger
from assets.models import Asset, Node, compute_parent_key
from assets.locks import NodeTreeUpdateLock
from orgs.utils import tmp_to_org
logger = get_logger(__file__)
@on_transaction_commit
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# `pre_clear` is not allowed, because that signal carries no `pk_set`
@ -25,136 +23,24 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
logger.debug('Recv asset nodes change signal, recompute node assets amount')
mapper = {PRE_ADD: add, POST_REMOVE: sub}
if action not in mapper:
return
operator = mapper[action]
with tmp_to_org(instance.org):
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator)
node_ids = [instance.id]
else:
asset_pk = instance.id
# Nodes directly related to the asset
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator)
node_ids = pk_set
update_node_assets_amount(*node_ids)
class NodeAssetsAmountUtils:
@merge_delay_run(ttl=5)
def update_node_assets_amount(*node_ids):
nodes = list(Node.objects.filter(id__in=node_ids))
logger.info('Update nodes assets amount: {} nodes'.format(len(node_ids)))
for node in nodes:
node.assets_amount = node.get_assets_amount()
@classmethod
def _remove_ancestor_keys(cls, ancestor_key, tree_set):
# `ancestor_key` must be non-empty here, to avoid an infinite loop on bad data
# Membership in the set marks whether a key still needs processing
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
@classmethod
def _is_asset_exists_in_node(cls, asset_pk, node_key):
exists = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).filter(id=asset_pk).exists()
return exists
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_nodes_asset_amount(cls, node_keys, asset_pk, operator):
"""
Update counters when one asset's relations to multiple nodes change
:param node_keys: set of node keys
:param asset_pk: asset id
:param operator: the operation (add or sub)
"""
# The ancestors of all related nodes form a partial tree
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# A related node may itself be an ancestor of another related node; if so, drop it from the related set
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# Iterate over the related nodes, handling each one and its ancestors
# Check whether this node contains the asset being processed
exists = cls._is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# If the asset is under this node, neither it nor its ancestors need handling
cls._remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# Not present, so this node itself must be updated
to_update_keys.append(key)
# `parent_key` must be non-empty here, to avoid an infinite loop on bad data
# Membership in the set marks whether a key still needs processing
while parent_key and parent_key in ancestor_keys:
exists = cls._is_asset_exists_in_node(asset_pk, parent_key)
if exists:
cls._remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add):
"""
Update counters when one node's relations to multiple assets change
:param node: node instance
:param asset_pk_set: set of asset ids; not modified inside this method
:param operator: the operation (add or sub)
* -> Node
# -> Asset
      * [3]
     /  \
    *    * [2]
   / \
  *   * [1]
 /   / \
* [a] # # [b]
"""
# Get the keys of node [1]'s ancestors including itself, i.e. the keys of nodes [1, 2, 3]
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# Iterate the ancestor keys in order [1] -> [2] -> [3]
# Check whether the node or its descendants already contain the assets being
# processed; drop those from the working set. They are duplicates: neither
# adding nor removing them changes this node's asset count
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# The working set is empty, so all assets were duplicates and no counts change
# And since this node already contains them, its ancestors must as well,
# so they can be skipped too
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
Node.objects.bulk_update(nodes, ['assets_amount'])

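For context on the rewrite above: the incremental add/sub bookkeeping per signal is replaced by a full recount of each affected node, and merge_delay_run collapses every call made within the ttl window into one batched call. A usage sketch of the merged-call semantics; the node objects are illustrative:

# Each call only queues its ids; roughly ttl seconds later a single merged
# call runs: update_node_assets_amount(node_a.id, node_b.id, node_c.id)
update_node_assets_amount(node_a.id)
update_node_assets_amount(node_b.id)
update_node_assets_amount(node_c.id)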
View File

@ -2,42 +2,35 @@
#
from django.db.models.signals import (
m2m_changed, post_save, post_delete
post_save, post_delete, m2m_changed
)
from django.dispatch import receiver
from django.utils.functional import LazyObject
from django.utils.functional import lazy
from assets.models import Asset, Node
from assets.models import Node, Asset
from common.decorators import merge_delay_run
from common.signals import django_ready
from common.utils import get_logger
from common.utils.connection import RedisPubSub
from orgs.models import Organization
logger = get_logger(__file__)
logger = get_logger(__name__)
# clear node assets mapping for memory
# ------------------------------------
node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)()
class NodeAssetsMappingForMemoryPubSub(LazyObject):
def _setup(self):
self._wrapped = RedisPubSub('fm.node_all_asset_ids_memory_mapping')
node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub()
def expire_node_assets_mapping_for_memory(org_id):
@merge_delay_run(ttl=5)
def expire_node_assets_mapping(*org_ids):
# Every process clears its own in-memory copy
org_id = str(org_id)
root_org_id = Organization.ROOT_ID
# The current process clears the cache entry
Node.expire_node_all_asset_ids_mapping_from_cache(org_id)
Node.expire_node_all_asset_ids_mapping_from_cache(root_org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
Node.expire_node_all_asset_ids_cache_mapping(root_org_id)
for org_id in set(org_ids):
org_id = str(org_id)
# The current process clears the cache entry
Node.expire_node_all_asset_ids_cache_mapping(org_id)
node_assets_mapping_pub_sub.publish(org_id)
@receiver(post_save, sender=Node)
@ -50,17 +43,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
need_expire = False
if need_expire:
expire_node_assets_mapping_for_memory(instance.org_id)
expire_node_assets_mapping(instance.org_id)
@receiver(post_delete, sender=Node)
def on_node_post_delete(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
expire_node_assets_mapping(instance.org_id)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
logger.debug("Recv asset nodes changed signal, expire memery node asset mapping")
expire_node_assets_mapping(instance.org_id)
@receiver(django_ready)
@ -69,7 +63,7 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
def handle_node_relation_change(org_id):
root_org_id = Organization.ROOT_ID
Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
Node.expire_node_all_asset_ids_mapping_from_memory(root_org_id)
Node.expire_node_all_asset_ids_memory_mapping(org_id)
Node.expire_node_all_asset_ids_memory_mapping(root_org_id)
node_assets_mapping_for_memory_pub_sub.subscribe(handle_node_relation_change)
node_assets_mapping_pub_sub.subscribe(handle_node_relation_change)

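The invalidation in this file is now two-tier: the process that changed node/asset relations deletes the shared cache entry directly, then broadcasts the org id over Redis so every process, itself included, drops its in-memory copy. A condensed sketch of the flow using the names from the hunks above; the handler body mirrors subscribe_node_assets_mapping_expire:

# Publisher side: runs in whichever process observed the change
Node.expire_node_all_asset_ids_cache_mapping(org_id)   # shared cache, once
node_assets_mapping_pub_sub.publish(org_id)            # fan out to all processes

# Subscriber side: registered on django_ready in every process
def handle_node_relation_change(org_id):
    Node.expire_node_all_asset_ids_memory_mapping(org_id)
    Node.expire_node_all_asset_ids_memory_mapping(Organization.ROOT_ID)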
View File

@ -8,8 +8,8 @@ from common.const.choices import Trigger
from orgs.utils import current_org
def generate_automation_execution_data(task_name, tp, child_snapshot=None):
child_snapshot = child_snapshot or {}
def generate_automation_execution_data(task_name, tp, task_snapshot=None):
task_snapshot = task_snapshot or {}
from assets.models import BaseAutomation
try:
eid = current_task.request.id
@ -25,13 +25,13 @@ def generate_automation_execution_data(task_name, tp, child_snapshot=None):
automation_instance = BaseAutomation()
snapshot = automation_instance.to_attr_json()
snapshot.update(data)
snapshot.update(child_snapshot)
snapshot.update(task_snapshot)
return {'id': eid, 'snapshot': snapshot}
def quickstart_automation(task_name, tp, child_snapshot=None):
def quickstart_automation(task_name, tp, task_snapshot=None):
from assets.models import AutomationExecution
data = generate_automation_execution_data(task_name, tp, child_snapshot)
data = generate_automation_execution_data(task_name, tp, task_snapshot)
while True:
try:

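The child_snapshot to task_snapshot rename is mechanical, but the calling convention is worth making explicit: a caller assembles a small dict of resource ids (and, for push tasks, secrets), and quickstart_automation merges it into the automation's snapshot before execution. A hedged sketch of a typical caller, modeled on the ping task further below; the values are illustrative:

task_snapshot = {
    'accounts': ['root'],                            # account usernames
    'assets': [str(asset.id) for asset in assets],   # asset ids as strings
}
quickstart_automation(task_name, AutomationTypes.ping, task_snapshot)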
View File

@ -1,65 +1,55 @@
# -*- coding: utf-8 -*-
#
from itertools import chain
from celery import shared_task
from django.utils.translation import gettext_noop, gettext_lazy as _
from assets.const import AutomationTypes
from common.utils import get_logger
from orgs.utils import org_aware_func
from orgs.utils import tmp_to_org
from .common import quickstart_automation
logger = get_logger(__file__)
__all__ = [
'update_assets_fact_util',
'gather_assets_facts_task',
'update_node_assets_hardware_info_manual',
'update_assets_hardware_info_manual',
]
def update_fact_util(assets=None, nodes=None, task_name=None):
@shared_task(queue="ansible", verbose_name=_('Gather assets facts'))
def gather_assets_facts_task(assets=None, nodes=None, task_name=None):
from assets.models import GatherFactsAutomation
if task_name is None:
task_name = gettext_noop("Update some assets hardware info. ")
task_name = gettext_noop("Gather assets facts")
task_name = GatherFactsAutomation.generate_unique_name(task_name)
nodes = nodes or []
assets = assets or []
child_snapshot = {
resources = list(chain(assets, nodes))
if not resources:
raise ValueError("nodes or assets must be given")
org_id = resources[0].org_id
task_snapshot = {
'assets': [str(asset.id) for asset in assets],
'nodes': [str(node.id) for node in nodes],
}
tp = AutomationTypes.gather_facts
quickstart_automation(task_name, tp, child_snapshot)
with tmp_to_org(org_id):
quickstart_automation(task_name, tp, task_snapshot)
@org_aware_func('assets')
def update_assets_fact_util(assets=None, task_name=None):
if assets is None:
logger.info("No assets to update hardware info")
return
update_fact_util(assets=assets, task_name=task_name)
@org_aware_func('nodes')
def update_nodes_fact_util(nodes=None, task_name=None):
if nodes is None:
logger.info("No nodes to update hardware info")
return
update_fact_util(nodes=nodes, task_name=task_name)
@shared_task(queue="ansible", verbose_name=_('Manually update the hardware information of assets'))
def update_assets_hardware_info_manual(asset_ids):
from assets.models import Asset
assets = Asset.objects.filter(id__in=asset_ids)
task_name = gettext_noop("Update assets hardware info: ")
update_assets_fact_util(assets=assets, task_name=task_name)
gather_assets_facts_task.delay(assets=assets, task_name=task_name)
@shared_task(queue="ansible", verbose_name=_('Manually update the hardware information of assets under a node'))
def update_node_assets_hardware_info_manual(node_id):
from assets.models import Node
node = Node.objects.get(id=node_id)
task_name = gettext_noop("Update node asset hardware information: ")
update_nodes_fact_util(nodes=[node], task_name=task_name)
gather_assets_facts_task.delay(nodes=[node], task_name=task_name)

View File

@ -24,8 +24,8 @@ def test_assets_connectivity_task(assets, task_name=None):
task_name = gettext_noop("Test assets connectivity ")
task_name = PingAutomation.generate_unique_name(task_name)
child_snapshot = {'assets': [str(asset.id) for asset in assets]}
quickstart_automation(task_name, AutomationTypes.ping, child_snapshot)
task_snapshot = {'assets': [str(asset.id) for asset in assets]}
quickstart_automation(task_name, AutomationTypes.ping, task_snapshot)
def test_assets_connectivity_manual(asset_ids):

View File

@ -6,7 +6,7 @@ from django.apps import apps
from django.conf import settings
from django.contrib.auth import BACKEND_SESSION_KEY
from django.db import transaction
from django.db.models.signals import post_save, pre_save, m2m_changed, pre_delete
from django.db.models.signals import pre_delete, pre_save, m2m_changed, post_save
from django.dispatch import receiver
from django.utils import timezone, translation
from django.utils.functional import LazyObject
@ -75,6 +75,7 @@ def on_m2m_changed(sender, action, instance, reverse, model, pk_set, **kwargs):
return
if not instance:
return
resource_type = instance._meta.verbose_name
current_instance = model_to_dict(instance, include_model_fields=False)
@ -151,6 +152,7 @@ def on_object_pre_create_or_update(sender, instance=None, raw=False, using=None,
@receiver(post_save)
def on_object_created_or_update(sender, instance=None, created=False, update_fields=None, **kwargs):
ok = signal_of_operate_log_whether_continue(
sender, instance, created, update_fields
)

View File

@ -327,6 +327,7 @@ class AuthACLMixin:
get_request_ip: Callable
def _check_login_acl(self, user, ip):
# ACL restricts user logins
acl = LoginACL.match(user, ip)
if not acl:

View File

@ -1,14 +1,8 @@
# -*- coding: utf-8 -*-
#
from django.utils.translation import ugettext_lazy as _
create_success_msg = _("%(name)s was created successfully")
update_success_msg = _("%(name)s was updated successfully")
FILE_END_GUARD = ">>> Content End <<<"
celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable
# https://docs.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties
LDAP_AD_ACCOUNT_DISABLE = 2
from .choices import *
from .common import *
from .crontab import *
from .http import *
from .signals import *

View File

@ -0,0 +1,11 @@
from django.utils.translation import ugettext_lazy as _
create_success_msg = _("%(name)s was created successfully")
update_success_msg = _("%(name)s was updated successfully")
FILE_END_GUARD = ">>> Content End <<<"
celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable
# https://docs.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties
LDAP_AD_ACCOUNT_DISABLE = 2

View File

@ -9,6 +9,8 @@ from concurrent.futures import ThreadPoolExecutor
from django.core.cache import cache
from django.db import transaction
from .utils import logger
def on_transaction_commit(func):
"""
@ -34,54 +36,64 @@ class Singleton(object):
return self._instance[self._cls]
def _run_func_if_is_last(ttl, func, *args, **kwargs):
ix = uuid.uuid4().__str__()
key = f'DELAY_RUN_{func.__name__}'
cache.set(key, ix, ttl)
st = (ttl - 2 > 1) and ttl - 2 or 2
time.sleep(st)
got = cache.get(key, None)
def default_suffix_key(*args, **kwargs):
return 'default'
if ix == got:
func(*args, **kwargs)
def key_by_org(*args, **kwargs):
return args[0].org_id
def _run_func_if_is_last(ttl, suffix_key, org, func, *args, **kwargs):
from orgs.utils import set_current_org
try:
set_current_org(org)
uid = uuid.uuid4().__str__()
suffix_key_func = suffix_key if suffix_key else default_suffix_key
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
key = f'DELAY_RUN_{func_name}_{key_suffix}'
cache.set(key, uid, ttl)
st = (ttl - 2 > 1) and ttl - 2 or 2
time.sleep(st)
ret = cache.get(key, None)
if uid == ret:
func(*args, **kwargs)
except Exception as e:
logger.error('delay run error: %s' % e)
executor = ThreadPoolExecutor(10)
def delay_run(ttl=5):
def delay_run(ttl=5, key=None):
"""
Delay execution; within ttl seconds, only the last call actually runs
:param ttl: debounce window in seconds
:param key: a callback that derives the grouping key from the call arguments
:return:
"""
def inner(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
executor.submit(_run_func_if_is_last, ttl, func, *args, **kwargs)
from orgs.utils import get_current_org
org = get_current_org()
executor.submit(_run_func_if_is_last, ttl, key, org, func, *args, **kwargs)
return wrapper
return inner
def _merge_run(ttl, func, *args, **kwargs):
if not args or not isinstance(args[0], (list, tuple)):
raise ValueError('args[0] must be list or tuple')
key = f'DELAY_MERGE_RUN_{func.__name__}'
ix = uuid.uuid4().__str__()
value = cache.get(key, [])
value.extend(args[0])
st = (ttl - 2 > 1) and ttl - 2 or 2
time.sleep(st)
got = cache.get(key, None)
if ix == got:
func(*args, **kwargs)
def merge_delay_run(ttl):
def merge_delay_run(ttl, key=None):
"""
Merge the arguments of delayed calls; within ttl seconds, only the last
(merged) call actually runs. func must take *args only
:param ttl: debounce window in seconds
:param key: a callback that derives the grouping key from the call arguments
:return:
"""
@ -93,42 +105,36 @@ def merge_delay_run(ttl):
if not str(param).startswith('*'):
raise ValueError('func args must be startswith *: %s' % func.__name__)
suffix_key_func = key if key else default_suffix_key
@functools.wraps(func)
def wrapper(*args):
key = f'DELAY_MERGE_RUN_{func.__name__}'
values = cache.get(key, [])
key_suffix = suffix_key_func(*args)
func_name = f'{func.__module__}_{func.__name__}'
cache_key = f'DELAY_MERGE_RUN_{func_name}_{key_suffix}'
values = cache.get(cache_key, [])
new_arg = [*values, *args]
cache.set(key, new_arg, ttl)
return delay_run(ttl)(func)(*new_arg)
cache.set(cache_key, new_arg, ttl)
return delay_run(ttl, suffix_key_func)(func)(*new_arg)
return wrapper
return inner
def delay_run(ttl=5):
"""
Delay execution; within ttl seconds, only the last call actually runs
:param ttl:
:return:
"""
def inner(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
executor.submit(_run_func_if_is_last, ttl, func, *args, **kwargs)
return wrapper
return inner
@delay_run(ttl=10)
@delay_run(ttl=5)
def test_delay_run(username, year=2000):
print("Hello, %s, now is %s" % (username, year))
@merge_delay_run(ttl=10)
@merge_delay_run(ttl=5, key=lambda *users: users[0][0])
def test_merge_delay_run(*users):
name = ','.join(users)
print("Hello, %s, now is %s" % (name, time.time()))
def do_test():
for i in range(10):
# test_delay_run('test', year=i)
test_merge_delay_run('test %s' % i)
test_merge_delay_run('best %s' % i)

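Taken together, the two decorators implement a per-key debounce: delay_run executes only the last call issued within ttl seconds, and merge_delay_run additionally concatenates the positional arguments of every suppressed call. The key callback (key_by_org above) partitions the debounce, so activity in one org does not swallow another org's calls. A usage sketch under those semantics; report_assets and the asset variables are illustrative:

@merge_delay_run(ttl=5, key=lambda *assets: assets[0].org_id)
def report_assets(*assets):
    # runs once per org per ttl window, with all merged arguments
    print('reporting %d assets' % len(assets))

# Three quick calls for the same org collapse into one execution,
# roughly 3-5 seconds after the first call:
report_assets(asset1)
report_assets(asset2)
report_assets(asset3)   # -> report_assets(asset1, asset2, asset3)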
View File

@ -1,6 +1,6 @@
from django.core.management.base import BaseCommand
from assets.signal_handlers.node_assets_mapping import expire_node_assets_mapping_for_memory
from assets.signal_handlers.node_assets_mapping import expire_node_assets_mapping as expire_node_assets_mapping_handler
from orgs.caches import OrgResourceStatisticsCache
from orgs.models import Organization
@ -10,7 +10,7 @@ def expire_node_assets_mapping():
org_ids = [*org_ids, '00000000-0000-0000-0000-000000000000']
for org_id in org_ids:
expire_node_assets_mapping_for_memory(org_id)
expire_node_assets_mapping_handler(org_id)
def expire_org_resource_statistics_cache():

View File

@ -60,16 +60,18 @@ def on_request_finished_logging_db_query(sender, **kwargs):
method = current_request.method
path = current_request.get_full_path()
# print(">>> [{}] {}".format(method, path))
# for table_name, queries in table_queries.items():
# if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
# continue
# print("- Table: {}".format(table_name))
# for i, query in enumerate(queries, 1):
# sql = query['sql']
# if not sql or not sql.startswith('SELECT'):
# continue
# print('\t{}. {}'.format(i, sql))
print(">>> [{}] {}".format(method, path))
for table_name, queries in table_queries.items():
if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
continue
if len(queries) < 3:
continue
print("- Table: {}".format(table_name))
for i, query in enumerate(queries, 1):
sql = query['sql']
if not sql or not sql.startswith('SELECT'):
continue
print('\t{}. {}'.format(i, sql))
logger.debug(">>> [{}] {}".format(method, path))
for name, counter in counters:

View File

@ -1,19 +1,19 @@
from django.db.models.signals import post_save, pre_delete, pre_save, post_delete
from django.dispatch import receiver
from orgs.models import Organization
from assets.models import Node
from accounts.models import Account
from assets.models import Asset, Domain
from assets.models import Node
from common.decorators import merge_delay_run
from common.utils import get_logger
from orgs.caches import OrgResourceStatisticsCache
from orgs.models import Organization
from orgs.utils import current_org
from perms.models import AssetPermission
from audits.models import UserLoginLog
from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding
from terminal.models import Session
from users.models import UserGroup, User
from users.signals import pre_user_leave_org
from terminal.models import Session
from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding
from assets.models import Asset, Domain
from orgs.caches import OrgResourceStatisticsCache
from orgs.utils import current_org
from common.utils import get_logger
logger = get_logger(__name__)
@ -58,42 +58,32 @@ def on_user_delete_refresh_cache(sender, instance, **kwargs):
refresh_user_amount_cache(instance)
# @receiver(m2m_changed, sender=OrganizationMember)
# def on_org_user_changed_refresh_cache(sender, action, instance, reverse, pk_set, **kwargs):
# if not action.startswith(POST_PREFIX):
# return
#
# if reverse:
# orgs = Organization.objects.filter(id__in=pk_set)
# else:
# orgs = [instance]
#
# for org in orgs:
# org_cache = OrgResourceStatisticsCache(org)
# org_cache.expire('users_amount')
# OrgResourceStatisticsCache(Organization.root()).expire('users_amount')
model_cache_field_mapper = {
Node: ['nodes_amount'],
Domain: ['domains_amount'],
UserGroup: ['groups_amount'],
Account: ['accounts_amount'],
RoleBinding: ['users_amount', 'new_users_amount_this_week'],
Asset: ['assets_amount', 'new_assets_amount_this_week'],
AssetPermission: ['asset_perms_amount'],
}
class OrgResourceStatisticsRefreshUtil:
model_cache_field_mapper = {
Node: ['nodes_amount'],
Domain: ['domains_amount'],
UserGroup: ['groups_amount'],
Account: ['accounts_amount'],
RoleBinding: ['users_amount', 'new_users_amount_this_week'],
Asset: ['assets_amount', 'new_assets_amount_this_week'],
AssetPermission: ['asset_perms_amount'],
}
@staticmethod
@merge_delay_run(ttl=5)
def refresh_org_fields(*org_fields):
for org, cache_field_name in org_fields:
OrgResourceStatisticsCache(org).expire(*cache_field_name)
OrgResourceStatisticsCache(Organization.root()).expire(*cache_field_name)
@classmethod
def refresh_if_need(cls, instance):
cache_field_name = cls.model_cache_field_mapper.get(type(instance))
cache_field_name = model_cache_field_mapper.get(type(instance))
if not cache_field_name:
return
OrgResourceStatisticsCache(Organization.root()).expire(*cache_field_name)
if getattr(instance, 'org', None):
OrgResourceStatisticsCache(instance.org).expire(*cache_field_name)
org = getattr(instance, 'org', None)
cls.refresh_org_fields((org, cache_field_name))
@receiver(post_save)

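Here the same merge pattern is applied to statistics-cache invalidation: each signal contributes one (org, fields) tuple, and within the ttl window all tuples merge into a single refresh_org_fields call, so the root org's counters are expired once per batch rather than once per signal. A sketch of the merged call shape; org_a and org_b are illustrative:

# Two signals inside the same 5s window...
OrgResourceStatisticsRefreshUtil.refresh_org_fields((org_a, ['assets_amount']))
OrgResourceStatisticsRefreshUtil.refresh_org_fields((org_b, ['nodes_amount']))
# ...collapse into one execution:
# refresh_org_fields((org_a, ['assets_amount']), (org_b, ['nodes_amount']))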
View File

@ -1,13 +1,10 @@
from rest_framework import generics
from rest_framework.permissions import AllowAny, IsAuthenticated
from django.conf import settings
from rest_framework import generics
from rest_framework.permissions import AllowAny
from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info
from common.utils import get_logger, lazyproperty, get_object_or_none
from authentication.models import ConnectionToken
from orgs.utils import tmp_to_root_org
from common.permissions import IsValidUserOrConnectionToken
from common.utils import get_logger, lazyproperty
from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info
from .. import serializers
from ..utils import get_interface_setting_or_default
@ -58,6 +55,3 @@ class PublicSettingApi(OpenPublicSettingApi):
# Surface the exception early
values[name] = getattr(settings, name)
return values

View File

@ -1,4 +1,4 @@
from django.db.models.signals import post_save, post_delete
from django.db.models.signals import post_delete, post_save
from django.dispatch import receiver
from assets.models import Asset

View File

@ -50,7 +50,7 @@ def main():
if resource == 'all':
generator_cls = resource_generator_mapper.values()
else:
generator_cls.push(resource_generator_mapper[resource])
generator_cls.append(resource_generator_mapper[resource])
for _cls in generator_cls:
generator = _cls(org_id=org_id, batch_size=batch_size)

View File

@ -1,7 +1,8 @@
#!/usr/bin/python
from random import seed
import time
from itertools import islice
from random import seed
from orgs.models import Organization
@ -18,7 +19,6 @@ class FakeDataGenerator:
o = Organization.get_instance(org_id, default=Organization.default())
if o:
o.change_to()
print('Current org is: {}'.format(o))
return o
def do_generate(self, batch, batch_size):
@ -38,8 +38,11 @@ class FakeDataGenerator:
batch = list(islice(counter, self.batch_size))
if not batch:
break
start = time.time()
self.do_generate(batch, self.batch_size)
end = time.time()
using = end - start
from_size = created
created += len(batch)
print('Generate %s: %s-%s' % (self.resource, from_size, created))
self.after_generate()
print('Generate %s: %s-%s [%.2fs]' % (self.resource, from_size, created, using))
self.after_generate()