diff --git a/apps/accounts/api/automations/gather_account.py b/apps/accounts/api/automations/gather_account.py index d7d8bb54f..c015116ce 100644 --- a/apps/accounts/api/automations/gather_account.py +++ b/apps/accounts/api/automations/gather_account.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # +from django.db import transaction from django.shortcuts import get_object_or_404 from rest_framework import status from rest_framework.decorators import action @@ -82,7 +83,8 @@ class GatheredAccountViewSet(OrgBulkModelViewSet): 'is_sync_account': False, 'name': 'Adhoc gather accounts: {}'.format(asset_id), } - execution.save() + with transaction.atomic(): + execution.save() execution.start() accounts = self.model.objects.filter(asset=asset) return self.get_paginated_response_from_queryset(accounts) diff --git a/apps/accounts/automations/base/manager.py b/apps/accounts/automations/base/manager.py index 652a8aa89..44a2c06c4 100644 --- a/apps/accounts/automations/base/manager.py +++ b/apps/accounts/automations/base/manager.py @@ -13,11 +13,3 @@ class AccountBasePlaybookManager(BasePlaybookManager): @property def platform_automation_methods(self): return platform_automation_methods - - def gen_report(self): - context = { - 'execution': self.execution, - 'summary': self.execution.summary, - 'result': self.execution.result - } - return render_to_string(self.template_path, context) diff --git a/apps/accounts/automations/change_secret/manager.py b/apps/accounts/automations/change_secret/manager.py index 294826802..8b083f917 100644 --- a/apps/accounts/automations/change_secret/manager.py +++ b/apps/accounts/automations/change_secret/manager.py @@ -7,11 +7,15 @@ from django.utils import timezone from django.utils.translation import gettext_lazy as _ from xlsxwriter import Workbook -from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy, ChangeSecretRecordStatusChoice +from accounts.const import ( + AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy, ChangeSecretRecordStatusChoice +) from accounts.models import ChangeSecretRecord, BaseAccountQuerySet from accounts.notifications import ChangeSecretExecutionTaskMsg, ChangeSecretFailedMsg from accounts.serializers import ChangeSecretRecordBackUpSerializer from assets.const import HostTypes +from common.db.utils import safe_db_connection +from common.decorators import bulk_create_decorator from common.utils import get_logger from common.utils.file import encrypt_and_compress_zip_file from common.utils.timezone import local_now_filename @@ -26,7 +30,7 @@ class ChangeSecretManager(AccountBasePlaybookManager): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.record_map = self.execution.snapshot.get('record_map', {}) + self.record_map = self.execution.snapshot.get('record_map', {}) # 这个是某个失败的记录重试 self.secret_type = self.execution.snapshot.get('secret_type') self.secret_strategy = self.execution.snapshot.get( 'secret_strategy', SecretStrategy.custom @@ -36,6 +40,7 @@ class ChangeSecretManager(AccountBasePlaybookManager): ) self.account_ids = self.execution.snapshot['accounts'] self.name_recorder_mapper = {} # 做个映射,方便后面处理 + self.pending_add_records = [] @classmethod def method_type(cls): @@ -52,18 +57,6 @@ class ChangeSecretManager(AccountBasePlaybookManager): kwargs['regexp'] = '.*{}$'.format(secret.split()[2].strip()) return kwargs - def secret_generator(self, secret_type): - return SecretGenerator( - self.secret_strategy, secret_type, - self.execution.snapshot.get('password_rules') - ) - - def get_secret(self, secret_type): - if self.secret_strategy == SecretStrategy.custom: - return self.execution.snapshot['secret'] - else: - return self.secret_generator(secret_type).get_secret() - def get_accounts(self, privilege_account) -> BaseAccountQuerySet | None: if not privilege_account: print('Not privilege account') @@ -81,10 +74,65 @@ class ChangeSecretManager(AccountBasePlaybookManager): ) return accounts - def host_callback( - self, host, asset=None, account=None, - automation=None, path_dir=None, **kwargs - ): + def gen_new_secret(self, account, path_dir): + private_key_path = None + if self.secret_type is None: + new_secret = account.secret + return new_secret, private_key_path + + if self.secret_strategy == SecretStrategy.custom: + new_secret = self.execution.snapshot['secret'] + else: + generator = SecretGenerator( + self.secret_strategy, self.secret_type, + self.execution.snapshot.get('password_rules') + ) + new_secret = generator.get_secret() + + if account.secret_type == SecretType.SSH_KEY: + private_key_path = self.generate_private_key_path(new_secret, path_dir) + new_secret = self.generate_public_key(new_secret) + return new_secret, private_key_path + + def get_or_create_record(self, asset, account, new_secret, name): + asset_account_id = f'{asset.id}-{account.id}' + + if asset_account_id in self.record_map: + record_id = self.record_map[asset_account_id] + recorder = ChangeSecretRecord.objects.filter(id=record_id).first() + else: + recorder = self.create_record(asset, account, new_secret) + + if recorder: + self.name_recorder_mapper[name] = recorder + + @bulk_create_decorator(ChangeSecretRecord) + def create_record(self, asset, account, new_secret): + recorder = ChangeSecretRecord( + asset=asset, account=account, execution=self.execution, + old_secret=account.secret, new_secret=new_secret, + comment=f'{account.username}@{asset.address}' + ) + return recorder + + def gen_change_secret_inventory(self, host, account, new_secret, private_key_path, asset): + h = deepcopy(host) + secret_type = account.secret_type + h['name'] += '(' + account.username + ')' + h['ssh_params'].update(self.get_ssh_params(account, new_secret, secret_type)) + h['account'] = { + 'name': account.name, + 'username': account.username, + 'secret_type': secret_type, + 'secret': account.escape_jinja2_syntax(new_secret), + 'private_key_path': private_key_path, + 'become': account.get_ansible_become_auth(), + } + if asset.platform.type == 'oracle': + h['account']['mode'] = 'sysdba' if account.privileged else None + return h + + def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs): host = super().host_callback( host, asset=asset, account=account, automation=automation, path_dir=path_dir, **kwargs @@ -93,72 +141,29 @@ class ChangeSecretManager(AccountBasePlaybookManager): return host host['check_conn_after_change'] = self.execution.snapshot.get('check_conn_after_change', True) + host['ssh_params'] = {} accounts = self.get_accounts(account) - error_msg = _("No pending accounts found") + error_msg = _("! No pending accounts found") if not accounts: print(f'{asset}: {error_msg}') return [] - records = [] + if asset.type == HostTypes.WINDOWS: + accounts = accounts.filter(secret_type=SecretType.PASSWORD) + inventory_hosts = [] if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY: - print(f'Windows {asset} does not support ssh key push') + print(f'! Windows {asset} does not support ssh key push') return inventory_hosts - if asset.type == HostTypes.WINDOWS: - accounts = accounts.filter(secret_type=SecretType.PASSWORD) - - host['ssh_params'] = {} for account in accounts: - h = deepcopy(host) - secret_type = account.secret_type - h['name'] += '(' + account.username + ')' - if self.secret_type is None: - new_secret = account.secret - else: - new_secret = self.get_secret(secret_type) - - if new_secret is None: - print(f'new_secret is None, account: {account}') - continue - - asset_account_id = f'{asset.id}-{account.id}' - if asset_account_id not in self.record_map: - recorder = ChangeSecretRecord( - asset=asset, account=account, execution=self.execution, - old_secret=account.secret, new_secret=new_secret, - comment=f'{account.username}@{asset.address}' - ) - records.append(recorder) - else: - record_id = self.record_map[asset_account_id] - try: - recorder = ChangeSecretRecord.objects.get(id=record_id) - except ChangeSecretRecord.DoesNotExist: - print(f"Record {record_id} not found") - continue - - self.name_recorder_mapper[h['name']] = recorder - - private_key_path = None - if secret_type == SecretType.SSH_KEY: - private_key_path = self.generate_private_key_path(new_secret, path_dir) - new_secret = self.generate_public_key(new_secret) - - h['ssh_params'].update(self.get_ssh_params(account, new_secret, secret_type)) - h['account'] = { - 'name': account.name, - 'username': account.username, - 'secret_type': secret_type, - 'secret': account.escape_jinja2_syntax(new_secret), - 'private_key_path': private_key_path, - 'become': account.get_ansible_become_auth(), - } - if asset.platform.type == 'oracle': - h['account']['mode'] = 'sysdba' if account.privileged else None + new_secret, private_key_path = self.gen_new_secret(account, path_dir) + h = self.gen_change_secret_inventory(host, account, new_secret, private_key_path, asset) + self.get_or_create_record(asset, account, new_secret, h['name']) inventory_hosts.append(h) - ChangeSecretRecord.objects.bulk_create(records) + + self.create_record.finish() return inventory_hosts def on_host_success(self, host, result): @@ -172,24 +177,13 @@ class ChangeSecretManager(AccountBasePlaybookManager): if not account: print("Account not found, deleted ?") return + account.secret = recorder.new_secret account.date_updated = timezone.now() - max_retries = 3 - retry_count = 0 - - while retry_count < max_retries: - try: - recorder.save() - account.save(update_fields=['secret', 'version', 'date_updated']) - break - except Exception as e: - retry_count += 1 - if retry_count == max_retries: - self.on_host_error(host, str(e), result) - else: - print(f'retry {retry_count} times for {host} recorder save error: {e}') - time.sleep(1) + with safe_db_connection(): + recorder.save(update_fields=['status', 'date_finished']) + account.save(update_fields=['secret', 'version', 'date_updated']) def on_host_error(self, host, error, result): recorder = self.name_recorder_mapper.get(host) @@ -222,21 +216,24 @@ class ChangeSecretManager(AccountBasePlaybookManager): else: failed += 1 total += 1 - summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total) return summary - def run(self, *args, **kwargs): - if self.secret_type and not self.check_secret(): - self.execution.status = 'success' - self.execution.date_finished = timezone.now() - self.execution.save() - return - super().run(*args, **kwargs) + def print_summary(self): recorders = list(self.name_recorder_mapper.values()) summary = self.get_summary(recorders) - print(summary, end='') + print('\n\n' + '-' * 80) + plan_execution_end = _('Plan execution end') + print('{} {}\n'.format(plan_execution_end, local_now_filename())) + time_cost = _('Time cost') + print('{}: {}s'.format(time_cost, self.duration)) + print(summary) + + def send_report_if_need(self, *args, **kwargs): + if self.secret_type and not self.check_secret(): + return + recorders = list(self.name_recorder_mapper.values()) if self.record_map: return @@ -262,6 +259,7 @@ class ChangeSecretManager(AccountBasePlaybookManager): if not recorders: return + summary = self.get_summary(recorders) self.send_recorder_mail(recipients, recorders, summary) def send_recorder_mail(self, recipients, recorders, summary): diff --git a/apps/accounts/automations/check_account/manager.py b/apps/accounts/automations/check_account/manager.py index eb60057e7..011196e5a 100644 --- a/apps/accounts/automations/check_account/manager.py +++ b/apps/accounts/automations/check_account/manager.py @@ -88,7 +88,7 @@ def check_account_secrets(accounts, assets): class CheckAccountManager(BaseManager): - batch_size=100 + batch_size = 100 def __init__(self, execution): super().__init__(execution) @@ -117,8 +117,14 @@ class CheckAccountManager(BaseManager): for k, v in result.items(): self.result[k].extend(v) + def get_report_subject(self): + return "Check account report of %s" % self.execution.id + + def get_report_template(self): + return 'accounts/check_account_report.html' + def print_summary(self): tmpl = "\n---\nSummary: \nok: %s, weak password: %s, no secret: %s, using time: %ss" % ( - self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.timedelta) + self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.duration) ) print(tmpl) diff --git a/apps/accounts/automations/gather_account/manager.py b/apps/accounts/automations/gather_account/manager.py index 5fe5eee9c..a20dc4533 100644 --- a/apps/accounts/automations/gather_account/manager.py +++ b/apps/accounts/automations/gather_account/manager.py @@ -6,6 +6,7 @@ from accounts.const import AutomationTypes from accounts.models import GatheredAccount, Account, AccountRisk from assets.models import Asset from common.const import ConfirmOrIgnore +from common.decorators import bulk_create_decorator, bulk_update_decorator from common.utils import get_logger from common.utils.strings import get_text_diff from orgs.utils import tmp_to_org @@ -62,35 +63,18 @@ class AnalyseAccountRisk: if not diff: return + risks = [] for k, v in diff.items(): - self.pending_add_risks.append(dict( + risks.append(dict( asset=ori_account.asset, username=ori_account.username, risk=k+'_changed', detail={'diff': v} )) - - def perform_save_risks(self, risks): - # 提前取出来,避免每次都查数据库 - assets = {r['asset'] for r in risks} - assets_risks = AccountRisk.objects.filter(asset__in=assets) - assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks} - - for d in risks: - detail = d.pop('detail', {}) - detail['datetime'] = self.now.isoformat() - key = f"{d['asset'].id}_{d['username']}_{d['risk']}" - found = assets_risks.get(key) - - if not found: - r = AccountRisk(**d, details=[detail]) - r.save() - continue - - found.details.append(detail) - found.save(update_fields=['details']) + self.save_or_update_risks(risks) def _analyse_datetime_changed(self, ori_account, d, asset, username): basic = {'asset': asset, 'username': username} + risks = [] for item in self.datetime_check_items: field = item['field'] risk = item['risk'] @@ -105,41 +89,61 @@ class AnalyseAccountRisk: continue if date and date < timezone.now() - delta: - self.pending_add_risks.append( + risks.append( dict(**basic, risk=risk, detail={'date': date.isoformat()}) ) - def batch_analyse_risk(self, asset, ori_account, d, batch_size=20): - if not self.check_risk: - return + self.save_or_update_risks(risks) + + def save_or_update_risks(self, risks): + # 提前取出来,避免每次都查数据库 + assets = {r['asset'] for r in risks} + assets_risks = AccountRisk.objects.filter(asset__in=assets) + assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks} - if asset is None: - if self.pending_add_risks: - self.perform_save_risks(self.pending_add_risks) - self.pending_add_risks = [] + for d in risks: + detail = d.pop('detail', {}) + detail['datetime'] = self.now.isoformat() + key = f"{d['asset'].id}_{d['username']}_{d['risk']}" + found = assets_risks.get(key) + + if not found: + self._create_risk(dict(**d, details=[detail])) + continue + + found.details.append(detail) + self._update_risk(found) + + @bulk_create_decorator(AccountRisk) + def _create_risk(self, data): + return AccountRisk(**data) + + @bulk_update_decorator(AccountRisk, update_fields=['details']) + def _update_risk(self, account): + return account + + def finish(self): + self._create_risk.finish() + self._update_risk.finish() + + def analyse_risk(self, asset, ori_account, d): + if not self.check_risk: return basic = {'asset': asset, 'username': d['username']} if ori_account: self._analyse_item_changed(ori_account, d) else: - self.pending_add_risks.append( - dict(**basic, risk='ghost') - ) + self._create_risk(dict(**basic, risk='new_account')) self._analyse_datetime_changed(ori_account, d, asset, d['username']) - if len(self.pending_add_risks) > batch_size: - self.batch_analyse_risk(None, None, {}) - class GatherAccountsManager(AccountBasePlaybookManager): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.host_asset_mapper = {} self.asset_account_info = {} - self.pending_add_accounts = [] - self.pending_update_accounts = [] self.asset_usernames_mapper = defaultdict(set) self.ori_asset_usernames = defaultdict(set) self.ori_gathered_usernames = defaultdict(set) @@ -204,7 +208,7 @@ class GatherAccountsManager(AccountBasePlaybookManager): ga_accounts = GatheredAccount.objects.filter(asset__in=assets) for account in ga_accounts: self.ori_gathered_usernames[account.asset].add(account.username) - key = '{}_{}'.format(account.asset.id, account.username) + key = '{}_{}'.format(account.asset_id, account.username) self.ori_gathered_accounts_mapper[key] = account def update_gather_accounts_status(self, asset): @@ -258,38 +262,25 @@ class GatherAccountsManager(AccountBasePlaybookManager): # 资产上没有的,标识为为存在 queryset.exclude(username__in=ori_users).filter(present=False).update(present=True) - def batch_create_gathered_account(self, d, batch_size=20): - if d is None: - if self.pending_add_accounts: - GatheredAccount.objects.bulk_create(self.pending_add_accounts, ignore_conflicts=True) - self.pending_add_accounts = [] - return - + @bulk_create_decorator(GatheredAccount) + def create_gathered_account(self, d): gathered_account = GatheredAccount() for k, v in d.items(): setattr(gathered_account, k, v) - self.pending_add_accounts.append(gathered_account) - - if len(self.pending_add_accounts) > batch_size: - self.batch_create_gathered_account(None) - - def batch_update_gathered_account(self, ori_account, d, batch_size=20): - if not ori_account or d is None: - if self.pending_update_accounts: - GatheredAccount.objects.bulk_update(self.pending_update_accounts, [*diff_items]) - self.pending_update_accounts = [] - return + return gathered_account + @bulk_update_decorator(GatheredAccount, update_fields=diff_items) + def update_gathered_account(self, ori_account, d): diff = get_items_diff(ori_account, d) - if diff: - for k in diff: - setattr(ori_account, k, d[k]) - self.pending_update_accounts.append(ori_account) - - if len(self.pending_update_accounts) > batch_size: - self.batch_update_gathered_account(None, None) + if not diff: + return + for k in diff: + setattr(ori_account, k, d[k]) + return ori_account - def do_run(self): + def do_run(self, *args, **kwargs): + super().do_run(*args, **kwargs) + self.prefetch_origin_account_usernames() risk_analyser = AnalyseAccountRisk(self.check_risk) for asset, accounts_data in self.asset_account_info.items(): @@ -300,21 +291,20 @@ class GatherAccountsManager(AccountBasePlaybookManager): ori_account = self.ori_gathered_accounts_mapper.get('{}_{}'.format(asset.id, username)) if not ori_account: - self.batch_create_gathered_account(d) + self.create_gathered_account(d) else: - self.batch_update_gathered_account(ori_account, d) - risk_analyser.batch_analyse_risk(asset, ori_account, d) + self.update_gathered_account(ori_account, d) + risk_analyser.analyse_risk(asset, ori_account, d) self.update_gather_accounts_status(asset) GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account) - self.batch_create_gathered_account(None) - self.batch_update_gathered_account(None, None) - risk_analyser.batch_analyse_risk(None, None, {}) + self.create_gathered_account.finish() + self.update_gathered_account.finish() + risk_analyser.finish() - def before_run(self): - super().before_run() - self.prefetch_origin_account_usernames() + def send_report_if_need(self): + pass def generate_send_users_and_change_info(self): recipients = self.execution.recipients diff --git a/apps/assets/automations/base/manager.py b/apps/assets/automations/base/manager.py index e9f6f4985..58a9604cc 100644 --- a/apps/assets/automations/base/manager.py +++ b/apps/assets/automations/base/manager.py @@ -98,60 +98,63 @@ class BaseManager: def get_assets_group_by_platform(self): return self.execution.all_assets_group_by_platform() - def before_run(self): + def pre_run(self): self.execution.date_start = timezone.now() self.execution.save(update_fields=['date_start']) - def get_report_subject(self): - return f'Automation {self.execution.id} finished' - - def send_report_if_need(self): - recipients = self.execution.recipients - if not recipients: - return - - report = self.gen_report() - report = transform(report) - subject = self.get_report_subject() - print("Send resport to: {}".format([str(r) for r in recipients])) - emails = [r.email for r in recipients if r.email] - send_mail_async(subject, report, emails, html_message=report) - def update_execution(self): self.duration = int(time.time() - self.time_start) self.execution.date_finished = timezone.now() self.execution.duration = self.duration self.execution.summary = self.summary self.execution.result = self.result + self.execution.status = 'success' with safe_db_connection(): - self.execution.save(update_fields=['date_finished', 'duration', 'summary', 'result']) + self.execution.save() def print_summary(self): pass - def get_template_path(self): + def get_report_template(self): raise NotImplementedError - def gen_report(self): - template_path = self.get_template_path() - context = { + def get_report_subject(self): + return f'Automation {self.execution.id} finished' + + def get_report_context(self): + return { 'execution': self.execution, 'summary': self.execution.summary, 'result': self.execution.result } + + def send_report_if_need(self): + recipients = self.execution.recipients + if not recipients: + return + + report = self.gen_report() + report = transform(report) + subject = self.get_report_subject() + emails = [r.email for r in recipients if r.email] + send_mail_async(subject, report, emails, html_message=report) + + def gen_report(self): + template_path = self.get_report_template() + context = self.get_report_context() data = render_to_string(template_path, context) return data - def after_run(self): + def post_run(self): self.update_execution() self.print_summary() self.send_report_if_need() def run(self, *args, **kwargs): - self.before_run() + self.pre_run() self.do_run(*args, **kwargs) - self.after_run() + self.post_run() def do_run(self, *args, **kwargs): raise NotImplementedError @@ -374,15 +377,19 @@ class BasePlaybookManager(BaseManager): if settings.DEBUG_DEV: print('host error: {} -> {}'.format(host, error)) - def _on_host_success(self, host, result, error, detail): - self.on_host_success(host, result) + def _on_host_success(self, host, result, hosts): + self.on_host_success(host, result.get("ok", '')) - def _on_host_error(self, host, result, error, detail): + def _on_host_error(self, host, result, hosts): + error = hosts.get(host, '') + detail = result.get('failures', '') or result.get('dark', '') self.on_host_error(host, error, detail) def on_runner_success(self, runner, cb): summary = cb.summary for state, hosts in summary.items(): + # 错误行为为,host 是 dict, ok 时是 list + if state == 'ok': handler = self._on_host_success elif state == 'skipped': @@ -392,9 +399,7 @@ class BasePlaybookManager(BaseManager): for host in hosts: result = cb.host_results.get(host) - error = hosts.get(host, '') - detail = result.get('failures', '') or result.get('dark', '') - handler(host, result, error, detail) + handler(host, result, hosts) def on_runner_failed(self, runner, e): print("Runner failed: {} {}".format(e, self)) diff --git a/apps/common/db/utils.py b/apps/common/db/utils.py index e61db0bfd..4fa84dd7f 100644 --- a/apps/common/db/utils.py +++ b/apps/common/db/utils.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from django.db import connections, transaction +from django.db import connections, transaction, connection from django.utils.encoding import force_str from common.utils import get_logger, signer, crypto @@ -13,7 +13,7 @@ def get_object_if_need(model, pk): try: return model.objects.get(id=pk) except model.DoesNotExist as e: - logger.error(f'DoesNotExist: <{model.__name__}:{pk}> not exist') + logger.error(f"DoesNotExist: <{model.__name__}:{pk}> not exist") raise e return pk @@ -26,8 +26,8 @@ def get_objects_if_need(model, pks): if len(objs) != len(pks): pks = set(pks) exists_pks = {o.id for o in objs} - not_found_pks = ','.join(pks - exists_pks) - logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>') + not_found_pks = ",".join(pks - exists_pks) + logger.error(f"DoesNotExist: <{model.__name__}: {not_found_pks}>") return objs return pks @@ -41,7 +41,7 @@ def get_objects(model, pks): pks = set(pks) exists_pks = {o.id for o in objs} not_found_pks = pks - exists_pks - logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>') + logger.error(f"DoesNotExist: <{model.__name__}: {not_found_pks}>") return objs @@ -53,13 +53,17 @@ def close_old_connections(): @contextmanager def safe_db_connection(): - close_old_connections() - yield - close_old_connections() + try: + close_old_connections() # 确保旧连接关闭 + if connection.connection: # 如果连接已关闭,重新连接 + connection.connect() + yield + finally: + close_old_connections() # 确保最终关闭连接 @contextmanager -def open_db_connection(alias='default'): +def open_db_connection(alias="default"): connection = transaction.get_connection(alias) try: connection.connect() diff --git a/apps/common/decorators.py b/apps/common/decorators.py index 394b4ec15..d46f5489f 100644 --- a/apps/common/decorators.py +++ b/apps/common/decorators.py @@ -11,8 +11,8 @@ from functools import wraps from django.db import transaction -from .utils import logger from .db.utils import open_db_connection +from .utils import logger def on_transaction_commit(func): @@ -294,3 +294,104 @@ def cached_method(ttl=20): return wrapper return decorator + + +def bulk_create_decorator(instance_model, batch_size=50): + """ + 装饰器,用于将实例批量保存,并提供 `commit` 方法提交剩余的实例。 + + :param instance_model: Django模型类,用于调用 bulk_create 方法。 + :param batch_size: 批量保存的阈值,默认50。 + """ + def decorator(func): + cache = [] # 缓存实例的列表 + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal cache + + # 调用被装饰的函数,生成一个实例 + instance = func(*args, **kwargs) + if instance is None: + return None + + # 添加实例到缓存 + cache.append(instance) + print(f"Instance added to cache. Cache size: {len(cache)}") + + # 如果缓存大小达到批量保存阈值,执行保存 + if len(cache) >= batch_size: + print(f"Batch size reached. Saving {len(cache)} instances...") + instance_model.objects.bulk_create(cache) + cache.clear() + + return instance + + # 提交剩余实例的方法 + def commit(): + nonlocal cache + if cache: + print(f"Committing remaining {len(cache)} instances...") + instance_model.objects.bulk_create(cache) + cache.clear() + + wrapper.finish = commit + return wrapper + + return decorator + + +def bulk_update_decorator(instance_model, batch_size=50, update_fields=None): + """ + 装饰器,用于批量更新实例,并提供 `commit` 方法提交剩余的更新。 + + :param instance_model: Django模型类,用于调用 bulk_update 方法。 + :param batch_size: 批量更新的阈值,默认50。 + :param update_fields: 指定要更新的字段列表,默认为 None,表示更新所有字段。 + """ + def decorator(func): + cache = [] # 缓存待更新实例的列表 + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal cache + + # 调用被装饰的函数,获取一个需要更新的实例 + instance = func(*args, **kwargs) + if instance is None: + return None + + # 添加实例到缓存 + cache.append(instance) + print(f"Instance added to update cache. Cache size: {len(cache)}") + + # 如果缓存大小达到批量更新阈值,执行更新 + if len(cache) >= batch_size: + print(f"Batch size reached. Updating {len(cache)} instances...") + instance_model.objects.bulk_update(cache, update_fields) + cache.clear() + + return instance + + # 提交剩余更新的方法 + def commit(): + nonlocal cache + if cache: + print(f"Committing remaining {len(cache)} instances..., {update_fields}") + # with transaction.atomic(): + # for c in cache: + # o = instance_model.objects.get(id=str(c.id)) + # print("Origin: ", o.id, o.sudoers) + # o.sudoers = c.sudoers + # o.save() + # print("New: ", c.id, c.sudoers) + instance_model.objects.bulk_update(cache, update_fields) + # print("Committing remaining instances... done, ", cache[0].sudoers, cache[0].id, instance_model) + # print(instance_model.objects.get(id=str(cache[0].id)).sudoers) + cache.clear() + + # 将 commit 方法绑定到装饰后的函数 + wrapper.finish = commit + return wrapper + + return decorator