perf: update bulk update create

pull/14387/merge
ibuler 5 days ago
parent 6f149b7c11
commit 11975842f6

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.db import transaction
from django.shortcuts import get_object_or_404
from rest_framework import status
from rest_framework.decorators import action
@@ -82,6 +83,7 @@ class GatheredAccountViewSet(OrgBulkModelViewSet):
'is_sync_account': False,
'name': 'Adhoc gather accounts: {}'.format(asset_id),
}
with transaction.atomic():
execution.save()
execution.start()
accounts = self.model.objects.filter(asset=asset)
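For reference, the pattern this hunk introduces, sketched standalone: the execution row is persisted inside a transaction before the gather job is kicked off. The helper name below is illustrative, and the exact nesting of start() is not visible in this extract, so it is shown after the commit here.

from django.db import transaction

def create_and_start(execution):
    # Persist the execution atomically; roll back cleanly if save() fails.
    with transaction.atomic():
        execution.save()
    # Kick off the adhoc gather job once the row is committed.
    execution.start()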

@@ -13,11 +13,3 @@ class AccountBasePlaybookManager(BasePlaybookManager):
@property
def platform_automation_methods(self):
return platform_automation_methods
def gen_report(self):
context = {
'execution': self.execution,
'summary': self.execution.summary,
'result': self.execution.result
}
return render_to_string(self.template_path, context)

@@ -7,11 +7,15 @@ from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from xlsxwriter import Workbook
from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy, ChangeSecretRecordStatusChoice
from accounts.const import (
AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy, ChangeSecretRecordStatusChoice
)
from accounts.models import ChangeSecretRecord, BaseAccountQuerySet
from accounts.notifications import ChangeSecretExecutionTaskMsg, ChangeSecretFailedMsg
from accounts.serializers import ChangeSecretRecordBackUpSerializer
from assets.const import HostTypes
from common.db.utils import safe_db_connection
from common.decorators import bulk_create_decorator
from common.utils import get_logger
from common.utils.file import encrypt_and_compress_zip_file
from common.utils.timezone import local_now_filename
@@ -26,7 +30,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.record_map = self.execution.snapshot.get('record_map', {})
self.record_map = self.execution.snapshot.get('record_map', {})  # map of records to retry after a failed run
self.secret_type = self.execution.snapshot.get('secret_type')
self.secret_strategy = self.execution.snapshot.get(
'secret_strategy', SecretStrategy.custom
@@ -36,6 +40,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
)
self.account_ids = self.execution.snapshot['accounts']
self.name_recorder_mapper = {}  # map inventory host name -> record, for later processing
self.pending_add_records = []
@classmethod
def method_type(cls):
@@ -52,18 +57,6 @@ class ChangeSecretManager(AccountBasePlaybookManager):
kwargs['regexp'] = '.*{}$'.format(secret.split()[2].strip())
return kwargs
def secret_generator(self, secret_type):
return SecretGenerator(
self.secret_strategy, secret_type,
self.execution.snapshot.get('password_rules')
)
def get_secret(self, secret_type):
if self.secret_strategy == SecretStrategy.custom:
return self.execution.snapshot['secret']
else:
return self.secret_generator(secret_type).get_secret()
def get_accounts(self, privilege_account) -> BaseAccountQuerySet | None:
if not privilege_account:
print('Not privilege account')
@@ -81,71 +74,51 @@ class ChangeSecretManager(AccountBasePlaybookManager):
)
return accounts
def host_callback(
self, host, asset=None, account=None,
automation=None, path_dir=None, **kwargs
):
host = super().host_callback(
host, asset=asset, account=account, automation=automation,
path_dir=path_dir, **kwargs
)
if host.get('error'):
return host
host['check_conn_after_change'] = self.execution.snapshot.get('check_conn_after_change', True)
def gen_new_secret(self, account, path_dir):
private_key_path = None
if self.secret_type is None:
new_secret = account.secret
return new_secret, private_key_path
accounts = self.get_accounts(account)
error_msg = _("No pending accounts found")
if not accounts:
print(f'{asset}: {error_msg}')
return []
if self.secret_strategy == SecretStrategy.custom:
new_secret = self.execution.snapshot['secret']
else:
generator = SecretGenerator(
self.secret_strategy, self.secret_type,
self.execution.snapshot.get('password_rules')
)
new_secret = generator.get_secret()
records = []
inventory_hosts = []
if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY:
print(f'Windows {asset} does not support ssh key push')
return inventory_hosts
if account.secret_type == SecretType.SSH_KEY:
private_key_path = self.generate_private_key_path(new_secret, path_dir)
new_secret = self.generate_public_key(new_secret)
return new_secret, private_key_path
if asset.type == HostTypes.WINDOWS:
accounts = accounts.filter(secret_type=SecretType.PASSWORD)
def get_or_create_record(self, asset, account, new_secret, name):
asset_account_id = f'{asset.id}-{account.id}'
host['ssh_params'] = {}
for account in accounts:
h = deepcopy(host)
secret_type = account.secret_type
h['name'] += '(' + account.username + ')'
if self.secret_type is None:
new_secret = account.secret
if asset_account_id in self.record_map:
record_id = self.record_map[asset_account_id]
recorder = ChangeSecretRecord.objects.filter(id=record_id).first()
else:
new_secret = self.get_secret(secret_type)
recorder = self.create_record(asset, account, new_secret)
if new_secret is None:
print(f'new_secret is None, account: {account}')
continue
if recorder:
self.name_recorder_mapper[name] = recorder
asset_account_id = f'{asset.id}-{account.id}'
if asset_account_id not in self.record_map:
@bulk_create_decorator(ChangeSecretRecord)
def create_record(self, asset, account, new_secret):
recorder = ChangeSecretRecord(
asset=asset, account=account, execution=self.execution,
old_secret=account.secret, new_secret=new_secret,
comment=f'{account.username}@{asset.address}'
)
records.append(recorder)
else:
record_id = self.record_map[asset_account_id]
try:
recorder = ChangeSecretRecord.objects.get(id=record_id)
except ChangeSecretRecord.DoesNotExist:
print(f"Record {record_id} not found")
continue
self.name_recorder_mapper[h['name']] = recorder
private_key_path = None
if secret_type == SecretType.SSH_KEY:
private_key_path = self.generate_private_key_path(new_secret, path_dir)
new_secret = self.generate_public_key(new_secret)
return recorder
def gen_change_secret_inventory(self, host, account, new_secret, private_key_path, asset):
h = deepcopy(host)
secret_type = account.secret_type
h['name'] += '(' + account.username + ')'
h['ssh_params'].update(self.get_ssh_params(account, new_secret, secret_type))
h['account'] = {
'name': account.name,
@@ -157,8 +130,40 @@ class ChangeSecretManager(AccountBasePlaybookManager):
}
if asset.platform.type == 'oracle':
h['account']['mode'] = 'sysdba' if account.privileged else None
return h
def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs):
host = super().host_callback(
host, asset=asset, account=account, automation=automation,
path_dir=path_dir, **kwargs
)
if host.get('error'):
return host
host['check_conn_after_change'] = self.execution.snapshot.get('check_conn_after_change', True)
host['ssh_params'] = {}
accounts = self.get_accounts(account)
error_msg = _("! No pending accounts found")
if not accounts:
print(f'{asset}: {error_msg}')
return []
if asset.type == HostTypes.WINDOWS:
accounts = accounts.filter(secret_type=SecretType.PASSWORD)
inventory_hosts = []
if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY:
print(f'! Windows {asset} does not support ssh key push')
return inventory_hosts
for account in accounts:
new_secret, private_key_path = self.gen_new_secret(account, path_dir)
h = self.gen_change_secret_inventory(host, account, new_secret, private_key_path, asset)
self.get_or_create_record(asset, account, new_secret, h['name'])
inventory_hosts.append(h)
ChangeSecretRecord.objects.bulk_create(records)
self.create_record.finish()
return inventory_hosts
def on_host_success(self, host, result):
@@ -172,24 +177,13 @@ class ChangeSecretManager(AccountBasePlaybookManager):
if not account:
print("Account not found, deleted ?")
return
account.secret = recorder.new_secret
account.date_updated = timezone.now()
max_retries = 3
retry_count = 0
while retry_count < max_retries:
try:
recorder.save()
with safe_db_connection():
recorder.save(update_fields=['status', 'date_finished'])
account.save(update_fields=['secret', 'version', 'date_updated'])
break
except Exception as e:
retry_count += 1
if retry_count == max_retries:
self.on_host_error(host, str(e), result)
else:
print(f'retry {retry_count} times for {host} recorder save error: {e}')
time.sleep(1)
def on_host_error(self, host, error, result):
recorder = self.name_recorder_mapper.get(host)
@@ -222,21 +216,24 @@ class ChangeSecretManager(AccountBasePlaybookManager):
else:
failed += 1
total += 1
summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
return summary
def run(self, *args, **kwargs):
if self.secret_type and not self.check_secret():
self.execution.status = 'success'
self.execution.date_finished = timezone.now()
self.execution.save()
return
super().run(*args, **kwargs)
def print_summary(self):
recorders = list(self.name_recorder_mapper.values())
summary = self.get_summary(recorders)
print(summary, end='')
print('\n\n' + '-' * 80)
plan_execution_end = _('Plan execution end')
print('{} {}\n'.format(plan_execution_end, local_now_filename()))
time_cost = _('Time cost')
print('{}: {}s'.format(time_cost, self.duration))
print(summary)
def send_report_if_need(self, *args, **kwargs):
if self.secret_type and not self.check_secret():
return
recorders = list(self.name_recorder_mapper.values())
if self.record_map:
return
@@ -262,6 +259,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
if not recorders:
return
summary = self.get_summary(recorders)
self.send_recorder_mail(recipients, recorders, summary)
def send_recorder_mail(self, recipients, recorders, summary):

@@ -88,7 +88,7 @@ def check_account_secrets(accounts, assets):
class CheckAccountManager(BaseManager):
batch_size=100
batch_size = 100
def __init__(self, execution):
super().__init__(execution)
@@ -117,8 +117,14 @@ class CheckAccountManager(BaseManager):
for k, v in result.items():
self.result[k].extend(v)
def get_report_subject(self):
return "Check account report of %s" % self.execution.id
def get_report_template(self):
return 'accounts/check_account_report.html'
def print_summary(self):
tmpl = "\n---\nSummary: \nok: %s, weak password: %s, no secret: %s, using time: %ss" % (
self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.timedelta)
self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.duration)
)
print(tmpl)

@@ -6,6 +6,7 @@ from accounts.const import AutomationTypes
from accounts.models import GatheredAccount, Account, AccountRisk
from assets.models import Asset
from common.const import ConfirmOrIgnore
from common.decorators import bulk_create_decorator, bulk_update_decorator
from common.utils import get_logger
from common.utils.strings import get_text_diff
from orgs.utils import tmp_to_org
@@ -62,35 +63,18 @@ class AnalyseAccountRisk:
if not diff:
return
risks = []
for k, v in diff.items():
self.pending_add_risks.append(dict(
risks.append(dict(
asset=ori_account.asset, username=ori_account.username,
risk=k+'_changed', detail={'diff': v}
))
def perform_save_risks(self, risks):
# fetch existing risks up front to avoid querying the database for every item
assets = {r['asset'] for r in risks}
assets_risks = AccountRisk.objects.filter(asset__in=assets)
assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks}
for d in risks:
detail = d.pop('detail', {})
detail['datetime'] = self.now.isoformat()
key = f"{d['asset'].id}_{d['username']}_{d['risk']}"
found = assets_risks.get(key)
if not found:
r = AccountRisk(**d, details=[detail])
r.save()
continue
found.details.append(detail)
found.save(update_fields=['details'])
self.save_or_update_risks(risks)
def _analyse_datetime_changed(self, ori_account, d, asset, username):
basic = {'asset': asset, 'username': username}
risks = []
for item in self.datetime_check_items:
field = item['field']
risk = item['risk']
@@ -105,41 +89,61 @@ class AnalyseAccountRisk:
continue
if date and date < timezone.now() - delta:
self.pending_add_risks.append(
risks.append(
dict(**basic, risk=risk, detail={'date': date.isoformat()})
)
def batch_analyse_risk(self, asset, ori_account, d, batch_size=20):
if not self.check_risk:
return
self.save_or_update_risks(risks)
if asset is None:
if self.pending_add_risks:
self.perform_save_risks(self.pending_add_risks)
self.pending_add_risks = []
def save_or_update_risks(self, risks):
# fetch existing risks up front to avoid querying the database for every item
assets = {r['asset'] for r in risks}
assets_risks = AccountRisk.objects.filter(asset__in=assets)
assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks}
for d in risks:
detail = d.pop('detail', {})
detail['datetime'] = self.now.isoformat()
key = f"{d['asset'].id}_{d['username']}_{d['risk']}"
found = assets_risks.get(key)
if not found:
self._create_risk(dict(**d, details=[detail]))
continue
found.details.append(detail)
self._update_risk(found)
@bulk_create_decorator(AccountRisk)
def _create_risk(self, data):
return AccountRisk(**data)
@bulk_update_decorator(AccountRisk, update_fields=['details'])
def _update_risk(self, account):
return account
def finish(self):
self._create_risk.finish()
self._update_risk.finish()
def analyse_risk(self, asset, ori_account, d):
if not self.check_risk:
return
basic = {'asset': asset, 'username': d['username']}
if ori_account:
self._analyse_item_changed(ori_account, d)
else:
self.pending_add_risks.append(
dict(**basic, risk='ghost')
)
self._create_risk(dict(**basic, risk='new_account'))
self._analyse_datetime_changed(ori_account, d, asset, d['username'])
if len(self.pending_add_risks) > batch_size:
self.batch_analyse_risk(None, None, {})
class GatherAccountsManager(AccountBasePlaybookManager):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.host_asset_mapper = {}
self.asset_account_info = {}
self.pending_add_accounts = []
self.pending_update_accounts = []
self.asset_usernames_mapper = defaultdict(set)
self.ori_asset_usernames = defaultdict(set)
self.ori_gathered_usernames = defaultdict(set)
@@ -204,7 +208,7 @@ class GatherAccountsManager(AccountBasePlaybookManager):
ga_accounts = GatheredAccount.objects.filter(asset__in=assets)
for account in ga_accounts:
self.ori_gathered_usernames[account.asset].add(account.username)
key = '{}_{}'.format(account.asset.id, account.username)
key = '{}_{}'.format(account.asset_id, account.username)
self.ori_gathered_accounts_mapper[key] = account
def update_gather_accounts_status(self, asset):
@@ -258,38 +262,25 @@ class GatherAccountsManager(AccountBasePlaybookManager):
# accounts the asset did not have before are marked as present
queryset.exclude(username__in=ori_users).filter(present=False).update(present=True)
def batch_create_gathered_account(self, d, batch_size=20):
if d is None:
if self.pending_add_accounts:
GatheredAccount.objects.bulk_create(self.pending_add_accounts, ignore_conflicts=True)
self.pending_add_accounts = []
return
@bulk_create_decorator(GatheredAccount)
def create_gathered_account(self, d):
gathered_account = GatheredAccount()
for k, v in d.items():
setattr(gathered_account, k, v)
self.pending_add_accounts.append(gathered_account)
if len(self.pending_add_accounts) > batch_size:
self.batch_create_gathered_account(None)
def batch_update_gathered_account(self, ori_account, d, batch_size=20):
if not ori_account or d is None:
if self.pending_update_accounts:
GatheredAccount.objects.bulk_update(self.pending_update_accounts, [*diff_items])
self.pending_update_accounts = []
return
return gathered_account
@bulk_update_decorator(GatheredAccount, update_fields=diff_items)
def update_gathered_account(self, ori_account, d):
diff = get_items_diff(ori_account, d)
if diff:
if not diff:
return
for k in diff:
setattr(ori_account, k, d[k])
self.pending_update_accounts.append(ori_account)
return ori_account
if len(self.pending_update_accounts) > batch_size:
self.batch_update_gathered_account(None, None)
def do_run(self):
def do_run(self, *args, **kwargs):
super().do_run(*args, **kwargs)
self.prefetch_origin_account_usernames()
risk_analyser = AnalyseAccountRisk(self.check_risk)
for asset, accounts_data in self.asset_account_info.items():
@@ -300,21 +291,20 @@ class GatherAccountsManager(AccountBasePlaybookManager):
ori_account = self.ori_gathered_accounts_mapper.get('{}_{}'.format(asset.id, username))
if not ori_account:
self.batch_create_gathered_account(d)
self.create_gathered_account(d)
else:
self.batch_update_gathered_account(ori_account, d)
risk_analyser.batch_analyse_risk(asset, ori_account, d)
self.update_gathered_account(ori_account, d)
risk_analyser.analyse_risk(asset, ori_account, d)
self.update_gather_accounts_status(asset)
GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account)
self.batch_create_gathered_account(None)
self.batch_update_gathered_account(None, None)
risk_analyser.batch_analyse_risk(None, None, {})
self.create_gathered_account.finish()
self.update_gathered_account.finish()
risk_analyser.finish()
def before_run(self):
super().before_run()
self.prefetch_origin_account_usernames()
def send_report_if_need(self):
pass
def generate_send_users_and_change_info(self):
recipients = self.execution.recipients

@@ -98,60 +98,63 @@ class BaseManager:
def get_assets_group_by_platform(self):
return self.execution.all_assets_group_by_platform()
def before_run(self):
def pre_run(self):
self.execution.date_start = timezone.now()
self.execution.save(update_fields=['date_start'])
def get_report_subject(self):
return f'Automation {self.execution.id} finished'
def send_report_if_need(self):
recipients = self.execution.recipients
if not recipients:
return
report = self.gen_report()
report = transform(report)
subject = self.get_report_subject()
print("Send resport to: {}".format([str(r) for r in recipients]))
emails = [r.email for r in recipients if r.email]
send_mail_async(subject, report, emails, html_message=report)
def update_execution(self):
self.duration = int(time.time() - self.time_start)
self.execution.date_finished = timezone.now()
self.execution.duration = self.duration
self.execution.summary = self.summary
self.execution.result = self.result
self.execution.status = 'success'
with safe_db_connection():
self.execution.save(update_fields=['date_finished', 'duration', 'summary', 'result'])
self.execution.save()
def print_summary(self):
pass
def get_template_path(self):
def get_report_template(self):
raise NotImplementedError
def gen_report(self):
template_path = self.get_template_path()
context = {
def get_report_subject(self):
return f'Automation {self.execution.id} finished'
def get_report_context(self):
return {
'execution': self.execution,
'summary': self.execution.summary,
'result': self.execution.result
}
def send_report_if_need(self):
recipients = self.execution.recipients
if not recipients:
return
report = self.gen_report()
report = transform(report)
subject = self.get_report_subject()
emails = [r.email for r in recipients if r.email]
send_mail_async(subject, report, emails, html_message=report)
def gen_report(self):
template_path = self.get_report_template()
context = self.get_report_context()
data = render_to_string(template_path, context)
return data
def after_run(self):
def post_run(self):
self.update_execution()
self.print_summary()
self.send_report_if_need()
def run(self, *args, **kwargs):
self.before_run()
self.pre_run()
self.do_run(*args, **kwargs)
self.after_run()
self.post_run()
def do_run(self, *args, **kwargs):
raise NotImplementedError
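Condensed, the renamed hooks leave BaseManager.run() as a small template method; a sketch with bodies elided (the class name here is illustrative):

class BaseManagerLifecycle:
    def pre_run(self):
        # was before_run(): stamp date_start and save the execution
        ...

    def do_run(self, *args, **kwargs):
        # implemented by subclasses
        raise NotImplementedError

    def post_run(self):
        # was after_run(): update the execution, print the summary, send the report
        ...

    def run(self, *args, **kwargs):
        self.pre_run()
        self.do_run(*args, **kwargs)
        self.post_run()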
@@ -374,15 +377,19 @@ class BasePlaybookManager(BaseManager):
if settings.DEBUG_DEV:
print('host error: {} -> {}'.format(host, error))
def _on_host_success(self, host, result, error, detail):
self.on_host_success(host, result)
def _on_host_success(self, host, result, hosts):
self.on_host_success(host, result.get("ok", ''))
def _on_host_error(self, host, result, error, detail):
def _on_host_error(self, host, result, hosts):
error = hosts.get(host, '')
detail = result.get('failures', '') or result.get('dark', '')
self.on_host_error(host, error, detail)
def on_runner_success(self, runner, cb):
summary = cb.summary
for state, hosts in summary.items():
# on error, hosts is a dict; on ok, it is a list
if state == 'ok':
handler = self._on_host_success
elif state == 'skipped':
@@ -392,9 +399,7 @@ class BasePlaybookManager(BaseManager):
for host in hosts:
result = cb.host_results.get(host)
error = hosts.get(host, '')
detail = result.get('failures', '') or result.get('dark', '')
handler(host, result, error, detail)
handler(host, result, hosts)
def on_runner_failed(self, runner, e):
print("Runner failed: {} {}".format(e, self))

@@ -1,6 +1,6 @@
from contextlib import contextmanager
from django.db import connections, transaction
from django.db import connections, transaction, connection
from django.utils.encoding import force_str
from common.utils import get_logger, signer, crypto
@@ -13,7 +13,7 @@ def get_object_if_need(model, pk):
try:
return model.objects.get(id=pk)
except model.DoesNotExist as e:
logger.error(f'DoesNotExist: <{model.__name__}:{pk}> not exist')
logger.error(f"DoesNotExist: <{model.__name__}:{pk}> not exist")
raise e
return pk
@@ -26,8 +26,8 @@ def get_objects_if_need(model, pks):
if len(objs) != len(pks):
pks = set(pks)
exists_pks = {o.id for o in objs}
not_found_pks = ','.join(pks - exists_pks)
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>')
not_found_pks = ",".join(pks - exists_pks)
logger.error(f"DoesNotExist: <{model.__name__}: {not_found_pks}>")
return objs
return pks
@@ -41,7 +41,7 @@ def get_objects(model, pks):
pks = set(pks)
exists_pks = {o.id for o in objs}
not_found_pks = pks - exists_pks
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>')
logger.error(f"DoesNotExist: <{model.__name__}: {not_found_pks}>")
return objs
@@ -53,13 +53,17 @@ def close_old_connections():
@contextmanager
def safe_db_connection():
close_old_connections()
try:
close_old_connections()  # make sure stale connections are closed
if not connection.connection:  # reconnect if the connection has been closed
connection.connect()
yield
close_old_connections()
finally:
close_old_connections()  # make sure the connection is closed at the end
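As used in the on_host_success hunk above, the context manager wraps short bursts of writes from long-running Ansible callbacks so they run on a live connection; a minimal sketch of the call pattern (recorder and account are assumed to be model instances loaded earlier in the task):

with safe_db_connection():
    recorder.save(update_fields=['status', 'date_finished'])
    account.save(update_fields=['secret', 'version', 'date_updated'])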
@contextmanager
def open_db_connection(alias='default'):
def open_db_connection(alias="default"):
connection = transaction.get_connection(alias)
try:
connection.connect()

@@ -11,8 +11,8 @@ from functools import wraps
from django.db import transaction
from .utils import logger
from .db.utils import open_db_connection
from .utils import logger
def on_transaction_commit(func):
@@ -294,3 +294,104 @@ def cached_method(ttl=20):
return wrapper
return decorator
def bulk_create_decorator(instance_model, batch_size=50):
"""
Decorator that buffers instances returned by the wrapped function and saves them with bulk_create, exposing a `finish` method to flush the remainder.
:param instance_model: Django model class whose bulk_create is called
:param batch_size: flush threshold, 50 by default
"""
def decorator(func):
cache = []  # buffer of instances waiting to be bulk-created
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal cache
# Call the wrapped function to build an instance
instance = func(*args, **kwargs)
if instance is None:
return None
# Buffer the instance
cache.append(instance)
print(f"Instance added to cache. Cache size: {len(cache)}")
# Flush once the buffer reaches the batch size
if len(cache) >= batch_size:
print(f"Batch size reached. Saving {len(cache)} instances...")
instance_model.objects.bulk_create(cache)
cache.clear()
return instance
# Flushes the instances still buffered
def commit():
nonlocal cache
if cache:
print(f"Committing remaining {len(cache)} instances...")
instance_model.objects.bulk_create(cache)
cache.clear()
wrapper.finish = commit
return wrapper
return decorator
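A minimal usage sketch of bulk_create_decorator (Book is a placeholder Django model, not part of this commit): the decorated function only builds and returns an instance; the decorator buffers them, calls bulk_create every batch_size instances, and .finish() saves whatever is left.

@bulk_create_decorator(Book, batch_size=50)
def create_book(title):
    # Just construct the instance; persistence is handled by the decorator.
    return Book(title=title)

for i in range(120):
    create_book(f'book-{i}')   # bulk_create fires at 50 and 100 instances
create_book.finish()           # flushes the remaining 20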
def bulk_update_decorator(instance_model, batch_size=50, update_fields=None):
"""
Decorator that buffers instances returned by the wrapped function and saves them with bulk_update, exposing a `finish` method to flush the remainder.
:param instance_model: Django model class whose bulk_update is called
:param batch_size: flush threshold, 50 by default
:param update_fields: list of fields to update; None means all fields
"""
def decorator(func):
cache = []  # buffer of instances waiting to be bulk-updated
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal cache
# Call the wrapped function to get an instance that needs updating
instance = func(*args, **kwargs)
if instance is None:
return None
# Buffer the instance
cache.append(instance)
print(f"Instance added to update cache. Cache size: {len(cache)}")
# Flush once the buffer reaches the batch size
if len(cache) >= batch_size:
print(f"Batch size reached. Updating {len(cache)} instances...")
instance_model.objects.bulk_update(cache, update_fields)
cache.clear()
return instance
# Flushes the updates still buffered
def commit():
nonlocal cache
if cache:
print(f"Committing remaining {len(cache)} instances..., {update_fields}")
# with transaction.atomic():
# for c in cache:
# o = instance_model.objects.get(id=str(c.id))
# print("Origin: ", o.id, o.sudoers)
# o.sudoers = c.sudoers
# o.save()
# print("New: ", c.id, c.sudoers)
instance_model.objects.bulk_update(cache, update_fields)
# print("Committing remaining instances... done, ", cache[0].sudoers, cache[0].id, instance_model)
# print(instance_model.objects.get(id=str(cache[0].id)).sudoers)
cache.clear()
# Expose the flush method on the decorated function as `finish`
wrapper.finish = commit
return wrapper
return decorator
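And a matching sketch for bulk_update_decorator (same placeholder Book model): the decorated function mutates and returns the instance, returning None skips buffering, and .finish() issues the final bulk_update.

@bulk_update_decorator(Book, batch_size=50, update_fields=['title'])
def rename_book(book, new_title):
    if book.title == new_title:
        return None            # nothing to update; the decorator skips None
    book.title = new_title
    return book

for book in Book.objects.all():
    rename_book(book, book.title.strip())
rename_book.finish()           # bulk_update for the instances still buffered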
