diff --git a/apps/accounts/serializers/account/account.py b/apps/accounts/serializers/account/account.py index 3017b950c..3db4cf445 100644 --- a/apps/accounts/serializers/account/account.py +++ b/apps/accounts/serializers/account/account.py @@ -1,13 +1,13 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers -from assets.models import Asset from accounts.const import SecretType, Source from accounts.models import Account, AccountTemplate from accounts.tasks import push_accounts_to_assets from assets.const import Category, AllTypes -from common.serializers.fields import ObjectRelatedField, LabeledChoiceField +from assets.models import Asset from common.serializers import SecretReadableMixin, BulkModelSerializer +from common.serializers.fields import ObjectRelatedField, LabeledChoiceField from .base import BaseAccountSerializer diff --git a/apps/accounts/signal_handlers.py b/apps/accounts/signal_handlers.py index d2ac73cc4..1323a17be 100644 --- a/apps/accounts/signal_handlers.py +++ b/apps/accounts/signal_handlers.py @@ -1,4 +1,5 @@ -from django.db.models.signals import pre_save, post_save +from django.db.models.signals import post_save +from django.db.models.signals import pre_save from django.dispatch import receiver from assets.models import Asset diff --git a/apps/accounts/tasks/common.py b/apps/accounts/tasks/common.py index 569a487a3..5c8aeedf5 100644 --- a/apps/accounts/tasks/common.py +++ b/apps/accounts/tasks/common.py @@ -4,9 +4,9 @@ from assets.tasks.common import generate_automation_execution_data from common.const.choices import Trigger -def automation_execute_start(task_name, tp, child_snapshot=None): +def automation_execute_start(task_name, tp, task_snapshot=None): from accounts.models import AutomationExecution - data = generate_automation_execution_data(task_name, tp, child_snapshot) + data = generate_automation_execution_data(task_name, tp, task_snapshot) while True: try: diff --git a/apps/accounts/tasks/gather_accounts.py b/apps/accounts/tasks/gather_accounts.py index dbfbe981e..434c2c0d8 100644 --- a/apps/accounts/tasks/gather_accounts.py +++ b/apps/accounts/tasks/gather_accounts.py @@ -1,13 +1,13 @@ # ~*~ coding: utf-8 ~*~ from celery import shared_task -from django.utils.translation import gettext_noop from django.utils.translation import gettext_lazy as _ +from django.utils.translation import gettext_noop +from accounts.const import AutomationTypes +from accounts.tasks.common import automation_execute_start from assets.models import Node from common.utils import get_logger from orgs.utils import org_aware_func -from accounts.const import AutomationTypes -from accounts.tasks.common import automation_execute_start __all__ = ['gather_asset_accounts'] logger = get_logger(__name__) @@ -18,11 +18,11 @@ def gather_asset_accounts_util(nodes, task_name): from accounts.models import GatherAccountsAutomation task_name = GatherAccountsAutomation.generate_unique_name(task_name) - child_snapshot = { + task_snapshot = { 'nodes': [str(node.id) for node in nodes], } tp = AutomationTypes.verify_account - automation_execute_start(task_name, tp, child_snapshot) + automation_execute_start(task_name, tp, task_snapshot) @shared_task(queue="ansible", verbose_name=_('Gather asset accounts')) diff --git a/apps/accounts/tasks/push_account.py b/apps/accounts/tasks/push_account.py index 4ae09d72b..2b3e9279a 100644 --- a/apps/accounts/tasks/push_account.py +++ b/apps/accounts/tasks/push_account.py @@ -1,10 +1,10 @@ from celery import shared_task from django.utils.translation import gettext_noop, ugettext_lazy as _ -from common.utils import get_logger -from orgs.utils import org_aware_func from accounts.const import AutomationTypes from accounts.tasks.common import automation_execute_start +from common.utils import get_logger +from orgs.utils import org_aware_func logger = get_logger(__file__) __all__ = [ @@ -13,14 +13,14 @@ __all__ = [ def push_util(account, assets, task_name): - child_snapshot = { + task_snapshot = { 'secret': account.secret, 'secret_type': account.secret_type, 'accounts': [account.username], 'assets': [str(asset.id) for asset in assets], } tp = AutomationTypes.push_account - automation_execute_start(task_name, tp, child_snapshot) + automation_execute_start(task_name, tp, task_snapshot) @org_aware_func("assets") diff --git a/apps/accounts/tasks/verify_account.py b/apps/accounts/tasks/verify_account.py index 4c478ce89..219c7fdaa 100644 --- a/apps/accounts/tasks/verify_account.py +++ b/apps/accounts/tasks/verify_account.py @@ -2,10 +2,10 @@ from celery import shared_task from django.utils.translation import gettext_noop from django.utils.translation import ugettext as _ -from common.utils import get_logger -from assets.const import GATEWAY_NAME from accounts.const import AutomationTypes from accounts.tasks.common import automation_execute_start +from assets.const import GATEWAY_NAME +from common.utils import get_logger from orgs.utils import org_aware_func logger = get_logger(__name__) @@ -18,11 +18,11 @@ def verify_connectivity_util(assets, tp, accounts, task_name): if not assets or not accounts: return account_usernames = list(accounts.values_list('username', flat=True)) - child_snapshot = { + task_snapshot = { 'accounts': account_usernames, 'assets': [str(asset.id) for asset in assets], } - automation_execute_start(task_name, tp, child_snapshot) + automation_execute_start(task_name, tp, task_snapshot) @org_aware_func("assets") diff --git a/apps/assets/models/node.py b/apps/assets/models/node.py index 5d1e633f5..8f6ee45f3 100644 --- a/apps/assets/models/node.py +++ b/apps/assets/models/node.py @@ -256,8 +256,6 @@ class FamilyMixin: class NodeAllAssetsMappingMixin: - # Use a new plan - # { org_id: { node_key: [ asset1_id, asset2_id ] } } orgid_nodekey_assetsid_mapping = defaultdict(dict) locks_for_get_mapping_from_cache = defaultdict(threading.Lock) @@ -273,20 +271,7 @@ class NodeAllAssetsMappingMixin: if _mapping: return _mapping - logger.debug(f'Get node asset mapping from memory failed, acquire thread lock: ' - f'thread={threading.get_ident()} ' - f'org_id={org_id}') with cls.get_lock(org_id): - logger.debug(f'Acquired thread lock ok. check if mapping is in memory now: ' - f'thread={threading.get_ident()} ' - f'org_id={org_id}') - _mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id) - if _mapping: - logger.debug(f'Mapping is already in memory now: ' - f'thread={threading.get_ident()} ' - f'org_id={org_id}') - return _mapping - _mapping = cls.get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(org_id) cls.set_node_all_asset_ids_mapping_to_memory(org_id, mapping=_mapping) return _mapping @@ -302,18 +287,18 @@ class NodeAllAssetsMappingMixin: cls.orgid_nodekey_assetsid_mapping[org_id] = mapping @classmethod - def expire_node_all_asset_ids_mapping_from_memory(cls, org_id): + def expire_node_all_asset_ids_memory_mapping(cls, org_id): org_id = str(org_id) cls.orgid_nodekey_assetsid_mapping.pop(org_id, None) @classmethod - def expire_all_orgs_node_all_asset_ids_mapping_from_memory(cls): + def expire_all_orgs_node_all_asset_ids_memory_mapping(cls): orgs = Organization.objects.all() org_ids = [str(org.id) for org in orgs] org_ids.append(Organization.ROOT_ID) - for id in org_ids: - cls.expire_node_all_asset_ids_mapping_from_memory(id) + for i in org_ids: + cls.expire_node_all_asset_ids_memory_mapping(i) # get order: from memory -> (from cache -> to generate) @classmethod @@ -332,25 +317,18 @@ class NodeAllAssetsMappingMixin: return _mapping _mapping = cls.generate_node_all_asset_ids_mapping(org_id) - cls.set_node_all_asset_ids_mapping_to_cache(org_id=org_id, mapping=_mapping) + cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id) + cache.set(cache_key, mapping, timeout=None) return _mapping @classmethod def get_node_all_asset_ids_mapping_from_cache(cls, org_id): cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id) mapping = cache.get(cache_key) - logger.info(f'Get node asset mapping from cache {bool(mapping)}: ' - f'thread={threading.get_ident()} ' - f'org_id={org_id}') return mapping @classmethod - def set_node_all_asset_ids_mapping_to_cache(cls, org_id, mapping): - cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id) - cache.set(cache_key, mapping, timeout=None) - - @classmethod - def expire_node_all_asset_ids_mapping_from_cache(cls, org_id): + def expire_node_all_asset_ids_cache_mapping(cls, org_id): cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id) cache.delete(cache_key) @@ -411,6 +389,14 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin): q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key) return Asset.objects.filter(q).distinct() + def get_assets_amount(self): + q = Q(node__key__startswith=f'{self.key}:') | Q(node__key=self.key) + return self.assets.through.objects.filter(q).count() + + def get_assets_account_by_children(self): + children = self.get_all_children().values_list() + return self.assets.through.objects.filter(node_id__in=children).count() + @classmethod def get_node_all_assets_by_key_v2(cls, key): # 最初的写法是: diff --git a/apps/assets/serializers/asset/common.py b/apps/assets/serializers/asset/common.py index 732fbe55d..35a4af460 100644 --- a/apps/assets/serializers/asset/common.py +++ b/apps/assets/serializers/asset/common.py @@ -130,7 +130,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali ] read_only_fields = [ 'category', 'type', 'connectivity', - 'date_verified', 'created_by', 'date_created' + 'date_verified', 'created_by', 'date_created', ] fields = fields_small + fields_fk + fields_m2m + read_only_fields extra_kwargs = { @@ -228,6 +228,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali node_id = request.query_params.get('node_id') if not node_id: return [] + nodes = Node.objects.filter(id=node_id) + return nodes def is_valid(self, raise_exception=False): self._set_protocols_default() diff --git a/apps/assets/signal_handlers/asset.py b/apps/assets/signal_handlers/asset.py index 55851f9c6..94aa86ff2 100644 --- a/apps/assets/signal_handlers/asset.py +++ b/apps/assets/signal_handlers/asset.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- # from django.db.models.signals import ( - post_save, m2m_changed, pre_delete, post_delete, pre_save + m2m_changed, pre_delete, post_delete, pre_save, post_save ) from django.dispatch import receiver from django.utils.translation import gettext_noop -from assets.models import Asset, Node, Cloud, Device, Host, Web, Database -from assets.tasks import test_assets_connectivity_task -from common.const.signals import POST_ADD, POST_REMOVE, PRE_REMOVE -from common.decorators import on_transaction_commit, merge_delay_run +from assets.models import Asset, Node, Host, Database, Device, Web, Cloud +from assets.tasks import test_assets_connectivity_task, gather_assets_facts_task +from common.const.signals import POST_REMOVE, PRE_REMOVE +from common.decorators import on_transaction_commit, merge_delay_run, key_by_org from common.utils import get_logger logger = get_logger(__file__) @@ -20,15 +20,33 @@ def on_node_pre_save(sender, instance: Node, **kwargs): instance.parent_key = instance.compute_parent_key() -@merge_delay_run(ttl=10) +@merge_delay_run(ttl=5, key=key_by_org) def test_assets_connectivity_handler(*assets): task_name = gettext_noop("Test assets connectivity ") test_assets_connectivity_task.delay(assets, task_name) -@merge_delay_run(ttl=10) +@merge_delay_run(ttl=5, key=key_by_org) def gather_assets_facts_handler(*assets): - pass + if not assets: + logger.info("No assets to update hardware info") + return + name = gettext_noop("Gather asset hardware info") + gather_assets_facts_task.delay(assets=assets, task_name=name) + + +@merge_delay_run(ttl=5, key=key_by_org) +def ensure_asset_has_node(*assets): + asset_ids = [asset.id for asset in assets] + has_ids = Asset.nodes.through.objects \ + .filter(asset_id__in=asset_ids) \ + .values_list('asset_id', flat=True) + need_ids = set(asset_ids) - set(has_ids) + if not need_ids: + return + + org_root = Node.org_root() + org_root.assets.add(*need_ids) @receiver(post_save, sender=Asset) @@ -42,38 +60,16 @@ def on_asset_create(sender, instance=None, created=False, **kwargs): return logger.info("Asset create signal recv: {}".format(instance)) + ensure_asset_has_node(instance) + # 获取资产硬件信息 - test_assets_connectivity_handler([instance]) - gather_assets_facts_handler([instance]) - - # 确保资产存在一个节点 - has_node = instance.nodes.all().exists() - if not has_node: - instance.nodes.add(Node.org_root()) - - -@receiver(m2m_changed, sender=Asset.nodes.through) -def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs): - """ - 本操作共访问 4 次数据库 - - 当资产的节点发生变化时,或者 当节点的资产关系发生变化时, - 节点下新增的资产,添加到节点关联的系统用户中 - """ - if action != POST_ADD: - return - logger.debug("Assets node add signal recv: {}".format(action)) - if reverse: - nodes = [instance.key] - asset_ids = pk_set - else: - nodes = Node.objects.filter(pk__in=pk_set).values_list('key', flat=True) - asset_ids = [instance.id] - - # 节点资产发生变化时,将资产关联到节点及祖先节点关联的系统用户, 只关注新增的 - nodes_ancestors_keys = set() - for node in nodes: - nodes_ancestors_keys.update(Node.get_node_ancestor_keys(node, with_self=True)) + auto_info = instance.auto_info + if auto_info.get('ping_enabled'): + logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name)) + test_assets_connectivity_handler(instance) + if auto_info.get('gather_facts_enabled'): + logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name)) + gather_assets_facts_handler(instance) RELATED_NODE_IDS = '_related_node_ids' @@ -82,19 +78,19 @@ RELATED_NODE_IDS = '_related_node_ids' @receiver(pre_delete, sender=Asset) def on_asset_delete(instance: Asset, using, **kwargs): logger.debug("Asset pre delete signal recv: {}".format(instance)) - node_ids = set(Node.objects.filter( - assets=instance - ).distinct().values_list('id', flat=True)) + node_ids = Node.objects.filter(assets=instance) \ + .distinct().values_list('id', flat=True) setattr(instance, RELATED_NODE_IDS, node_ids) m2m_changed.send( - sender=Asset.nodes.through, instance=instance, reverse=False, - model=Node, pk_set=node_ids, using=using, action=PRE_REMOVE + sender=Asset.nodes.through, instance=instance, + reverse=False, model=Node, pk_set=node_ids, + using=using, action=PRE_REMOVE ) @receiver(post_delete, sender=Asset) def on_asset_post_delete(instance: Asset, using, **kwargs): - logger.debug("Asset delete signal recv: {}".format(instance)) + logger.debug("Asset post delete signal recv: {}".format(instance)) node_ids = getattr(instance, RELATED_NODE_IDS, None) if node_ids: m2m_changed.send( diff --git a/apps/assets/signal_handlers/node_assets_amount.py b/apps/assets/signal_handlers/node_assets_amount.py index cb9d2fcc1..314171a71 100644 --- a/apps/assets/signal_handlers/node_assets_amount.py +++ b/apps/assets/signal_handlers/node_assets_amount.py @@ -1,22 +1,21 @@ # -*- coding: utf-8 -*- # from operator import add, sub -from django.db.models import Q, F + +from django.db.models.signals import m2m_changed from django.dispatch import receiver -from django.db.models.signals import ( - m2m_changed -) -from orgs.utils import ensure_in_real_or_default_org, tmp_to_org +from assets.models import Asset, Node from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR +from common.decorators import on_transaction_commit, merge_delay_run from common.utils import get_logger -from assets.models import Asset, Node, compute_parent_key -from assets.locks import NodeTreeUpdateLock - +from orgs.utils import tmp_to_org +from ..tasks import check_node_assets_amount_task logger = get_logger(__file__) +@on_transaction_commit @receiver(m2m_changed, sender=Asset.nodes.through) def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs): # 不允许 `pre_clear` ,因为该信号没有 `pk_set` @@ -25,136 +24,29 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs): if action in refused: raise ValueError - mapper = { - PRE_ADD: add, - POST_REMOVE: sub - } + logger.debug('Recv asset nodes change signal, recompute node assets amount') + mapper = {PRE_ADD: add, POST_REMOVE: sub} if action not in mapper: return - operator = mapper[action] - with tmp_to_org(instance.org): if reverse: - node: Node = instance - asset_pk_set = set(pk_set) - NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator) + node_ids = [instance.id] else: - asset_pk = instance.id - # 与资产直接关联的节点 - node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True)) - NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator) + node_ids = pk_set + update_nodes_assets_amount(*node_ids) -class NodeAssetsAmountUtils: +@merge_delay_run(ttl=5) +def update_nodes_assets_amount(*node_ids): + nodes = list(Node.objects.filter(id__in=node_ids)) + logger.info('Update nodes assets amount: {} nodes'.format(len(node_ids))) - @classmethod - def _remove_ancestor_keys(cls, ancestor_key, tree_set): - # 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环 - # 判断是否在集合里,来区分是否已被处理过 - while ancestor_key and ancestor_key in tree_set: - tree_set.remove(ancestor_key) - ancestor_key = compute_parent_key(ancestor_key) + if len(node_ids) > 100: + check_node_assets_amount_task.delay() + return - @classmethod - def _is_asset_exists_in_node(cls, asset_pk, node_key): - exists = Asset.objects.filter( - Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key) - ).filter(id=asset_pk).exists() - return exists + for node in nodes: + node.assets_amount = node.get_assets_amount() - @classmethod - @ensure_in_real_or_default_org - @NodeTreeUpdateLock() - def update_nodes_asset_amount(cls, node_keys, asset_pk, operator): - """ - 一个资产与多个节点关系变化时,更新计数 - - :param node_keys: 节点 id 的集合 - :param asset_pk: 资产 id - :param operator: 操作 - """ - - # 所有相关节点的祖先节点,组成一棵局部树 - ancestor_keys = set() - for key in node_keys: - ancestor_keys.update(Node.get_node_ancestor_keys(key)) - - # 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉 - node_keys -= ancestor_keys - - to_update_keys = [] - for key in node_keys: - # 遍历相关节点,处理它及其祖先节点 - # 查询该节点是否包含待处理资产 - exists = cls._is_asset_exists_in_node(asset_pk, key) - parent_key = compute_parent_key(key) - - if exists: - # 如果资产在该节点,那么他及其祖先节点都不用处理 - cls._remove_ancestor_keys(parent_key, ancestor_keys) - continue - else: - # 不存在,要更新本节点 - to_update_keys.append(key) - # 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环 - # 判断是否在集合里,来区分是否已被处理过 - while parent_key and parent_key in ancestor_keys: - exists = cls._is_asset_exists_in_node(asset_pk, parent_key) - if exists: - cls._remove_ancestor_keys(parent_key, ancestor_keys) - break - else: - to_update_keys.append(parent_key) - ancestor_keys.remove(parent_key) - parent_key = compute_parent_key(parent_key) - - Node.objects.filter(key__in=to_update_keys).update( - assets_amount=operator(F('assets_amount'), 1) - ) - - @classmethod - @ensure_in_real_or_default_org - @NodeTreeUpdateLock() - def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add): - """ - 一个节点与多个资产关系变化时,更新计数 - - :param node: 节点实例 - :param asset_pk_set: 资产的`id`集合, 内部不会修改该值 - :param operator: 操作 - * -> Node - # -> Asset - - * [3] - / \ - * * [2] - / \ - * * [1] - / / \ - * [a] # # [b] - - """ - # 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key` - ancestor_keys = node.get_ancestor_keys(with_self=True) - ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key') - to_update = [] - for ancestor in ancestors: - # 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3] - # 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的 - # 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量 - - asset_pk_set -= set(Asset.objects.filter( - id__in=asset_pk_set - ).filter( - Q(nodes__key__istartswith=f'{ancestor.key}:') | - Q(nodes__key=ancestor.key) - ).distinct().values_list('id', flat=True)) - if not asset_pk_set: - # 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量 - # 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用 - # 处理了 - break - ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set)) - to_update.append(ancestor) - Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key')) + Node.objects.bulk_update(nodes, ['assets_amount']) diff --git a/apps/assets/signal_handlers/node_assets_mapping.py b/apps/assets/signal_handlers/node_assets_mapping.py index 27640fc76..54855a29a 100644 --- a/apps/assets/signal_handlers/node_assets_mapping.py +++ b/apps/assets/signal_handlers/node_assets_mapping.py @@ -2,42 +2,35 @@ # from django.db.models.signals import ( - m2m_changed, post_save, post_delete + post_save, post_delete, m2m_changed ) from django.dispatch import receiver -from django.utils.functional import LazyObject +from django.utils.functional import lazy -from assets.models import Asset, Node +from assets.models import Node, Asset +from common.decorators import merge_delay_run from common.signals import django_ready from common.utils import get_logger from common.utils.connection import RedisPubSub from orgs.models import Organization -logger = get_logger(__file__) - +logger = get_logger(__name__) # clear node assets mapping for memory # ------------------------------------ +node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)() -class NodeAssetsMappingForMemoryPubSub(LazyObject): - def _setup(self): - self._wrapped = RedisPubSub('fm.node_all_asset_ids_memory_mapping') - - -node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub() - - -def expire_node_assets_mapping_for_memory(org_id): +@merge_delay_run(ttl=5) +def expire_node_assets_mapping(*org_ids): # 所有进程清除(自己的 memory 数据) - org_id = str(org_id) root_org_id = Organization.ROOT_ID - - # 当前进程清除(cache 数据) - Node.expire_node_all_asset_ids_mapping_from_cache(org_id) - Node.expire_node_all_asset_ids_mapping_from_cache(root_org_id) - - node_assets_mapping_for_memory_pub_sub.publish(org_id) + Node.expire_node_all_asset_ids_cache_mapping(root_org_id) + for org_id in set(org_ids): + org_id = str(org_id) + # 当前进程清除(cache 数据) + Node.expire_node_all_asset_ids_cache_mapping(org_id) + node_assets_mapping_pub_sub.publish(org_id) @receiver(post_save, sender=Node) @@ -50,17 +43,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs): need_expire = False if need_expire: - expire_node_assets_mapping_for_memory(instance.org_id) + expire_node_assets_mapping(instance.org_id) @receiver(post_delete, sender=Node) def on_node_post_delete(sender, instance, **kwargs): - expire_node_assets_mapping_for_memory(instance.org_id) + expire_node_assets_mapping(instance.org_id) @receiver(m2m_changed, sender=Asset.nodes.through) def on_node_asset_change(sender, instance, **kwargs): - expire_node_assets_mapping_for_memory(instance.org_id) + logger.debug("Recv asset nodes changed signal, expire memery node asset mapping") + expire_node_assets_mapping(instance.org_id) @receiver(django_ready) @@ -69,7 +63,7 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs): def handle_node_relation_change(org_id): root_org_id = Organization.ROOT_ID - Node.expire_node_all_asset_ids_mapping_from_memory(org_id) - Node.expire_node_all_asset_ids_mapping_from_memory(root_org_id) + Node.expire_node_all_asset_ids_memory_mapping(org_id) + Node.expire_node_all_asset_ids_memory_mapping(root_org_id) - node_assets_mapping_for_memory_pub_sub.subscribe(handle_node_relation_change) + node_assets_mapping_pub_sub.subscribe(handle_node_relation_change) diff --git a/apps/assets/tasks/common.py b/apps/assets/tasks/common.py index a00485fe5..1931347b1 100644 --- a/apps/assets/tasks/common.py +++ b/apps/assets/tasks/common.py @@ -8,8 +8,8 @@ from common.const.choices import Trigger from orgs.utils import current_org -def generate_automation_execution_data(task_name, tp, child_snapshot=None): - child_snapshot = child_snapshot or {} +def generate_automation_execution_data(task_name, tp, task_snapshot=None): + task_snapshot = task_snapshot or {} from assets.models import BaseAutomation try: eid = current_task.request.id @@ -25,13 +25,13 @@ def generate_automation_execution_data(task_name, tp, child_snapshot=None): automation_instance = BaseAutomation() snapshot = automation_instance.to_attr_json() snapshot.update(data) - snapshot.update(child_snapshot) + snapshot.update(task_snapshot) return {'id': eid, 'snapshot': snapshot} -def quickstart_automation(task_name, tp, child_snapshot=None): +def quickstart_automation(task_name, tp, task_snapshot=None): from assets.models import AutomationExecution - data = generate_automation_execution_data(task_name, tp, child_snapshot) + data = generate_automation_execution_data(task_name, tp, task_snapshot) while True: try: diff --git a/apps/assets/tasks/gather_facts.py b/apps/assets/tasks/gather_facts.py index 835621568..764342087 100644 --- a/apps/assets/tasks/gather_facts.py +++ b/apps/assets/tasks/gather_facts.py @@ -1,65 +1,55 @@ # -*- coding: utf-8 -*- # +from itertools import chain + from celery import shared_task from django.utils.translation import gettext_noop, gettext_lazy as _ from assets.const import AutomationTypes from common.utils import get_logger -from orgs.utils import org_aware_func +from orgs.utils import tmp_to_org from .common import quickstart_automation logger = get_logger(__file__) __all__ = [ - 'update_assets_fact_util', + 'gather_assets_facts_task', 'update_node_assets_hardware_info_manual', 'update_assets_hardware_info_manual', ] -def update_fact_util(assets=None, nodes=None, task_name=None): +@shared_task(queue="ansible", verbose_name=_('Gather assets facts')) +def gather_assets_facts_task(assets=None, nodes=None, task_name=None): from assets.models import GatherFactsAutomation if task_name is None: - task_name = gettext_noop("Update some assets hardware info. ") + task_name = gettext_noop("Gather assets facts") task_name = GatherFactsAutomation.generate_unique_name(task_name) nodes = nodes or [] assets = assets or [] - child_snapshot = { + resources = chain(assets, nodes) + if not resources: + raise ValueError("nodes or assets must be given") + org_id = list(resources)[0].org_id + task_snapshot = { 'assets': [str(asset.id) for asset in assets], 'nodes': [str(node.id) for node in nodes], } tp = AutomationTypes.gather_facts - quickstart_automation(task_name, tp, child_snapshot) + + with tmp_to_org(org_id): + quickstart_automation(task_name, tp, task_snapshot) -@org_aware_func('assets') -def update_assets_fact_util(assets=None, task_name=None): - if assets is None: - logger.info("No assets to update hardware info") - return - - update_fact_util(assets=assets, task_name=task_name) - - -@org_aware_func('nodes') -def update_nodes_fact_util(nodes=None, task_name=None): - if nodes is None: - logger.info("No nodes to update hardware info") - return - update_fact_util(nodes=nodes, task_name=task_name) - - -@shared_task(queue="ansible", verbose_name=_('Manually update the hardware information of assets')) def update_assets_hardware_info_manual(asset_ids): from assets.models import Asset assets = Asset.objects.filter(id__in=asset_ids) task_name = gettext_noop("Update assets hardware info: ") - update_assets_fact_util(assets=assets, task_name=task_name) + gather_assets_facts_task.delay(assets=assets, task_name=task_name) -@shared_task(queue="ansible", verbose_name=_('Manually update the hardware information of assets under a node')) def update_node_assets_hardware_info_manual(node_id): from assets.models import Node node = Node.objects.get(id=node_id) task_name = gettext_noop("Update node asset hardware information: ") - update_nodes_fact_util(nodes=[node], task_name=task_name) + gather_assets_facts_task.delay(nodes=[node], task_name=task_name) diff --git a/apps/assets/tasks/ping.py b/apps/assets/tasks/ping.py index dd48af59d..5cf0a94f8 100644 --- a/apps/assets/tasks/ping.py +++ b/apps/assets/tasks/ping.py @@ -24,8 +24,8 @@ def test_assets_connectivity_task(assets, task_name=None): task_name = gettext_noop("Test assets connectivity ") task_name = PingAutomation.generate_unique_name(task_name) - child_snapshot = {'assets': [str(asset.id) for asset in assets]} - quickstart_automation(task_name, AutomationTypes.ping, child_snapshot) + task_snapshot = {'assets': [str(asset.id) for asset in assets]} + quickstart_automation(task_name, AutomationTypes.ping, task_snapshot) def test_assets_connectivity_manual(asset_ids): diff --git a/apps/assets/utils/node.py b/apps/assets/utils/node.py index deaa7d32b..b21508b64 100644 --- a/apps/assets/utils/node.py +++ b/apps/assets/utils/node.py @@ -1,12 +1,12 @@ # ~*~ coding: utf-8 ~*~ # from collections import defaultdict + +from common.db.models import output_as_string +from common.struct import Stack from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit from common.utils.http import is_true -from common.struct import Stack -from common.db.models import output_as_string from orgs.utils import ensure_in_real_or_default_org, current_org - from ..locks import NodeTreeUpdateLock from ..models import Node, Asset @@ -25,11 +25,11 @@ def check_node_assets_amount(): for node in nodes: nodeid_nodekey_mapper[node.id] = node.key - for nodeid, assetid in nodeid_assetid_pairs: - if nodeid not in nodeid_nodekey_mapper: + for node_id, asset_id in nodeid_assetid_pairs: + if node_id not in nodeid_nodekey_mapper: continue - nodekey = nodeid_nodekey_mapper[nodeid] - nodekey_assetids_mapper[nodekey].add(assetid) + node_key = nodeid_nodekey_mapper[node_id] + nodekey_assetids_mapper[node_key].add(asset_id) util = NodeAssetsUtil(nodes, nodekey_assetids_mapper) util.generate() diff --git a/apps/audits/signal_handlers.py b/apps/audits/signal_handlers.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/common/const/__init__.py b/apps/common/const/__init__.py index a6bcd95c8..ca2a510d5 100644 --- a/apps/common/const/__init__.py +++ b/apps/common/const/__init__.py @@ -1,14 +1,8 @@ # -*- coding: utf-8 -*- # -from django.utils.translation import ugettext_lazy as _ - -create_success_msg = _("%(name)s was created successfully") -update_success_msg = _("%(name)s was updated successfully") -FILE_END_GUARD = ">>> Content End <<<" -celery_task_pre_key = "CELERY_" -KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}" - -# AD User AccountDisable -# https://docs.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties -LDAP_AD_ACCOUNT_DISABLE = 2 +from .choices import * +from .common import * +from .crontab import * +from .http import * +from .signals import * diff --git a/apps/common/const/common.py b/apps/common/const/common.py new file mode 100644 index 000000000..86ae030d9 --- /dev/null +++ b/apps/common/const/common.py @@ -0,0 +1,11 @@ +from django.utils.translation import ugettext_lazy as _ + +create_success_msg = _("%(name)s was created successfully") +update_success_msg = _("%(name)s was updated successfully") +FILE_END_GUARD = ">>> Content End <<<" +celery_task_pre_key = "CELERY_" +KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}" + +# AD User AccountDisable +# https://docs.microsoft.com/en-us/troubleshoot/windows-server/identity/useraccountcontrol-manipulate-account-properties +LDAP_AD_ACCOUNT_DISABLE = 2 diff --git a/apps/common/decorators.py b/apps/common/decorators.py index 91a6ca623..7e2c41ec5 100644 --- a/apps/common/decorators.py +++ b/apps/common/decorators.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- # +import asyncio import functools import inspect +import threading import time import uuid from concurrent.futures import ThreadPoolExecutor @@ -9,6 +11,8 @@ from concurrent.futures import ThreadPoolExecutor from django.core.cache import cache from django.db import transaction +from .utils import logger + def on_transaction_commit(func): """ @@ -34,54 +38,124 @@ class Singleton(object): return self._instance[self._cls] -def _run_func_if_is_last(ttl, func, *args, **kwargs): - ix = uuid.uuid4().__str__() - key = f'DELAY_RUN_{func.__name__}' - cache.set(key, ix, ttl) - st = (ttl - 2 > 1) and ttl - 2 or 2 - time.sleep(st) - got = cache.get(key, None) +def default_suffix_key(*args, **kwargs): + return 'default' - if ix == got: + +def key_by_org(*args, **kwargs): + return args[0].org_id + + +def _run_func_if_is_last(ttl, suffix_key, org, func, *args, **kwargs): + from orgs.utils import set_current_org + + try: + set_current_org(org) + uid = uuid.uuid4().__str__() + suffix_key_func = suffix_key if suffix_key else default_suffix_key + func_name = f'{func.__module__}_{func.__name__}' + key_suffix = suffix_key_func(*args, **kwargs) + key = f'DELAY_RUN_{func_name}_{key_suffix}' + cache.set(key, uid, ttl) + st = (ttl - 2 > 1) and ttl - 2 or 2 + time.sleep(st) + ret = cache.get(key, None) + + if uid == ret: + func(*args, **kwargs) + except Exception as e: + logger.error('delay run error: %s' % e) + + +class LoopThread(threading.Thread): + def __init__(self, loop, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loop = loop + + def run(self) -> None: + asyncio.set_event_loop(loop) + self.loop.run_forever() + print('loop stopped') + + +loop = asyncio.get_event_loop() +loop_thread = LoopThread(loop) +loop_thread.daemon = True +loop_thread.start() +executor = ThreadPoolExecutor(max_workers=5, thread_name_prefix='debouncer') + + +class Debouncer(object): + def __init__(self, callback, check, delay, *args, **kwargs): + self.callback = callback + self.check = check + self.delay = delay + + async def __call__(self, *args, **kwargs): + await asyncio.sleep(self.delay) + ok = await self._check(*args, **kwargs) + if ok: + await loop.run_in_executor(executor, self.callback, *args) + + async def _check(self, *args, **kwargs): + if asyncio.iscoroutinefunction(self.check): + return await self.check(*args, **kwargs) + return await loop.run_in_executor(executor, self.check) + + +def _run_func_with_org(org, func, *args, **kwargs): + from orgs.utils import set_current_org + + try: + set_current_org(org) func(*args, **kwargs) + except Exception as e: + logger.error('delay run error: %s' % e) -executor = ThreadPoolExecutor(10) +def delay_run(ttl=5, key=None): + """ + 延迟执行函数, 在 ttl 秒内, 只执行最后一次 + :param ttl: + :param key: 是否合并参数, 一个 callback + :return: + """ - -def delay_run(ttl=5): def inner(func): @functools.wraps(func) def wrapper(*args, **kwargs): - executor.submit(_run_func_if_is_last, ttl, func, *args, **kwargs) + from orgs.utils import get_current_org + org = get_current_org() + suffix_key_func = key if key else default_suffix_key + uid = uuid.uuid4().__str__() + func_name = f'{func.__module__}_{func.__name__}' + key_suffix = suffix_key_func(*args, **kwargs) + cache_key = f'DELAY_RUN_{func_name}_{key_suffix}' + # 延迟两倍时间,防止缓存过期,导致校验失败 + cache.set(cache_key, uid, ttl * 2) + + def _check_func(key_id, key_value): + ret = cache.get(key_id, None) + return key_value == ret + + check_func_partial = functools.partial(_check_func, cache_key, uid) + run_func_partial = functools.partial(_run_func_with_org, org, func) + asyncio.run_coroutine_threadsafe( + Debouncer(run_func_partial, check_func_partial, ttl)(*args, **kwargs), + loop=loop + ) return wrapper return inner -def _merge_run(ttl, func, *args, **kwargs): - if not args or not isinstance(args[0], (list, tuple)): - raise ValueError('args[0] must be list or tuple') - - key = f'DELAY_MERGE_RUN_{func.__name__}' - ix = uuid.uuid4().__str__() - value = cache.get(key, []) - value.extend(args[0]) - - st = (ttl - 2 > 1) and ttl - 2 or 2 - time.sleep(st) - got = cache.get(key, None) - - if ix == got: - func(*args, **kwargs) - - -def merge_delay_run(ttl): +def merge_delay_run(ttl, key=None): """ 合并 func 参数,延迟执行, 在 ttl 秒内, 只执行最后一次 func 参数必须是 *args :param ttl: + :param key: 是否合并参数, 一个 callback :return: """ @@ -93,42 +167,50 @@ def merge_delay_run(ttl): if not str(param).startswith('*'): raise ValueError('func args must be startswith *: %s' % func.__name__) + suffix_key_func = key if key else default_suffix_key + @functools.wraps(func) def wrapper(*args): - key = f'DELAY_MERGE_RUN_{func.__name__}' - values = cache.get(key, []) + key_suffix = suffix_key_func(*args) + func_name = f'{func.__module__}_{func.__name__}' + cache_key = f'DELAY_MERGE_RUN_{func_name}_{key_suffix}' + values = cache.get(cache_key, []) new_arg = [*values, *args] - cache.set(key, new_arg, ttl) - return delay_run(ttl)(func)(*new_arg) + cache.set(cache_key, new_arg, ttl) + return delay_run(ttl, suffix_key_func)(func)(*new_arg) return wrapper return inner -def delay_run(ttl=5): - """ - 延迟执行函数, 在 ttl 秒内, 只执行最后一次 - :param ttl: - :return: - """ - - def inner(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - executor.submit(_run_func_if_is_last, ttl, func, *args, **kwargs) - - return wrapper - - return inner - - -@delay_run(ttl=10) +@delay_run(ttl=5) def test_delay_run(username, year=2000): print("Hello, %s, now is %s" % (username, year)) -@merge_delay_run(ttl=10) +@merge_delay_run(ttl=5, key=lambda *users: users[0][0]) def test_merge_delay_run(*users): name = ','.join(users) + time.sleep(2) print("Hello, %s, now is %s" % (name, time.time())) + + +@merge_delay_run(ttl=5, key=lambda *users: users[0][0]) +def test_merge_delay_run(*users): + name = ','.join(users) + time.sleep(2) + print("Hello, %s, now is %s" % (name, time.time())) + + +def do_test(): + s = time.time() + print("start : %s" % time.time()) + for i in range(100): + # test_delay_run('test', year=i) + test_merge_delay_run('test %s' % i) + test_merge_delay_run('best %s' % i) + + end = time.time() + using = end - s + print("end : %s, using: %s" % (end, using)) diff --git a/apps/common/management/commands/expire_caches.py b/apps/common/management/commands/expire_caches.py index bc20a3cf5..d83b995b0 100644 --- a/apps/common/management/commands/expire_caches.py +++ b/apps/common/management/commands/expire_caches.py @@ -1,6 +1,6 @@ from django.core.management.base import BaseCommand -from assets.signal_handlers.node_assets_mapping import expire_node_assets_mapping_for_memory +from assets.signal_handlers.node_assets_mapping import expire_node_assets_mapping from orgs.caches import OrgResourceStatisticsCache from orgs.models import Organization @@ -10,7 +10,7 @@ def expire_node_assets_mapping(): org_ids = [*org_ids, '00000000-0000-0000-0000-000000000000'] for org_id in org_ids: - expire_node_assets_mapping_for_memory(org_id) + expire_node_assets_mapping(org_id) def expire_org_resource_statistics_cache(): diff --git a/apps/common/signal_handlers.py b/apps/common/signal_handlers.py index c65adf296..d03c76798 100644 --- a/apps/common/signal_handlers.py +++ b/apps/common/signal_handlers.py @@ -60,16 +60,18 @@ def on_request_finished_logging_db_query(sender, **kwargs): method = current_request.method path = current_request.get_full_path() - # print(">>> [{}] {}".format(method, path)) - # for table_name, queries in table_queries.items(): - # if table_name.startswith('rbac_') or table_name.startswith('auth_permission'): - # continue - # print("- Table: {}".format(table_name)) - # for i, query in enumerate(queries, 1): - # sql = query['sql'] - # if not sql or not sql.startswith('SELECT'): - # continue - # print('\t{}. {}'.format(i, sql)) + print(">>> [{}] {}".format(method, path)) + for table_name, queries in table_queries.items(): + if table_name.startswith('rbac_') or table_name.startswith('auth_permission'): + continue + if len(queries) < 3: + continue + print("- Table: {}".format(table_name)) + for i, query in enumerate(queries, 1): + sql = query['sql'] + if not sql or not sql.startswith('SELECT'): + continue + print('\t{}. {}'.format(i, sql)) logger.debug(">>> [{}] {}".format(method, path)) for name, counter in counters: diff --git a/apps/orgs/signal_handlers/cache.py b/apps/orgs/signal_handlers/cache.py index c702dc6d2..bf166d9a8 100644 --- a/apps/orgs/signal_handlers/cache.py +++ b/apps/orgs/signal_handlers/cache.py @@ -1,19 +1,19 @@ from django.db.models.signals import post_save, pre_delete, pre_save, post_delete from django.dispatch import receiver -from orgs.models import Organization -from assets.models import Node from accounts.models import Account +from assets.models import Asset, Domain +from assets.models import Node +from common.decorators import merge_delay_run +from common.utils import get_logger +from orgs.caches import OrgResourceStatisticsCache +from orgs.models import Organization +from orgs.utils import current_org from perms.models import AssetPermission -from audits.models import UserLoginLog +from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding +from terminal.models import Session from users.models import UserGroup, User from users.signals import pre_user_leave_org -from terminal.models import Session -from rbac.models import OrgRoleBinding, SystemRoleBinding, RoleBinding -from assets.models import Asset, Domain -from orgs.caches import OrgResourceStatisticsCache -from orgs.utils import current_org -from common.utils import get_logger logger = get_logger(__name__) @@ -62,42 +62,32 @@ def on_user_delete_refresh_cache(sender, instance, **kwargs): refresh_all_orgs_user_amount_cache(instance) -# @receiver(m2m_changed, sender=OrganizationMember) -# def on_org_user_changed_refresh_cache(sender, action, instance, reverse, pk_set, **kwargs): -# if not action.startswith(POST_PREFIX): -# return -# -# if reverse: -# orgs = Organization.objects.filter(id__in=pk_set) -# else: -# orgs = [instance] -# -# for org in orgs: -# org_cache = OrgResourceStatisticsCache(org) -# org_cache.expire('users_amount') -# OrgResourceStatisticsCache(Organization.root()).expire('users_amount') +model_cache_field_mapper = { + Node: ['nodes_amount'], + Domain: ['domains_amount'], + UserGroup: ['groups_amount'], + Account: ['accounts_amount'], + RoleBinding: ['users_amount', 'new_users_amount_this_week'], + Asset: ['assets_amount', 'new_assets_amount_this_week'], + AssetPermission: ['asset_perms_amount'], +} class OrgResourceStatisticsRefreshUtil: - model_cache_field_mapper = { - Node: ['nodes_amount'], - Domain: ['domains_amount'], - UserGroup: ['groups_amount'], - Account: ['accounts_amount'], - RoleBinding: ['users_amount', 'new_users_amount_this_week'], - Asset: ['assets_amount', 'new_assets_amount_this_week'], - AssetPermission: ['asset_perms_amount'], - - } + @staticmethod + @merge_delay_run(ttl=5) + def refresh_org_fields(*org_fields): + for org, cache_field_name in org_fields: + OrgResourceStatisticsCache(org).expire(*cache_field_name) + OrgResourceStatisticsCache(Organization.root()).expire(*cache_field_name) @classmethod def refresh_if_need(cls, instance): - cache_field_name = cls.model_cache_field_mapper.get(type(instance)) + cache_field_name = model_cache_field_mapper.get(type(instance)) if not cache_field_name: return - OrgResourceStatisticsCache(Organization.root()).expire(*cache_field_name) - if getattr(instance, 'org', None): - OrgResourceStatisticsCache(instance.org).expire(*cache_field_name) + org = getattr(instance, 'org', None) + cls.refresh_org_fields((org, cache_field_name)) @receiver(post_save) diff --git a/apps/settings/api/public.py b/apps/settings/api/public.py index 161b8002a..ab313cd4f 100644 --- a/apps/settings/api/public.py +++ b/apps/settings/api/public.py @@ -1,13 +1,10 @@ -from rest_framework import generics -from rest_framework.permissions import AllowAny, IsAuthenticated from django.conf import settings +from rest_framework import generics +from rest_framework.permissions import AllowAny -from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info -from common.utils import get_logger, lazyproperty, get_object_or_none -from authentication.models import ConnectionToken -from orgs.utils import tmp_to_root_org from common.permissions import IsValidUserOrConnectionToken - +from common.utils import get_logger, lazyproperty +from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info from .. import serializers from ..utils import get_interface_setting_or_default @@ -58,6 +55,3 @@ class PublicSettingApi(OpenPublicSettingApi): # 提前把异常爆出来 values[name] = getattr(settings, name) return values - - - diff --git a/apps/terminal/signal_handlers/db_port.py b/apps/terminal/signal_handlers/db_port.py index bacbc1f0c..d6118d40e 100644 --- a/apps/terminal/signal_handlers/db_port.py +++ b/apps/terminal/signal_handlers/db_port.py @@ -1,4 +1,4 @@ -from django.db.models.signals import post_save, post_delete +from django.db.models.signals import post_delete, post_save from django.dispatch import receiver from assets.models import Asset diff --git a/utils/generate_fake_data/generate.py b/utils/generate_fake_data/generate.py index ee9a5cac4..394e27113 100644 --- a/utils/generate_fake_data/generate.py +++ b/utils/generate_fake_data/generate.py @@ -50,7 +50,7 @@ def main(): if resource == 'all': generator_cls = resource_generator_mapper.values() else: - generator_cls.push(resource_generator_mapper[resource]) + generator_cls.append(resource_generator_mapper[resource]) for _cls in generator_cls: generator = _cls(org_id=org_id, batch_size=batch_size) diff --git a/utils/generate_fake_data/resources/base.py b/utils/generate_fake_data/resources/base.py index 723baa6dd..828789dd7 100644 --- a/utils/generate_fake_data/resources/base.py +++ b/utils/generate_fake_data/resources/base.py @@ -1,7 +1,8 @@ #!/usr/bin/python -from random import seed +import time from itertools import islice +from random import seed from orgs.models import Organization @@ -18,7 +19,6 @@ class FakeDataGenerator: o = Organization.get_instance(org_id, default=Organization.default()) if o: o.change_to() - print('Current org is: {}'.format(o)) return o def do_generate(self, batch, batch_size): @@ -38,8 +38,11 @@ class FakeDataGenerator: batch = list(islice(counter, self.batch_size)) if not batch: break + start = time.time() self.do_generate(batch, self.batch_size) + end = time.time() + using = end - start from_size = created created += len(batch) - print('Generate %s: %s-%s' % (self.resource, from_size, created)) - self.after_generate() \ No newline at end of file + print('Generate %s: %s-%s [{}s]' % (self.resource, from_size, created, using)) + self.after_generate()