perf: update bulk update create

pull/14387/merge
ibuler 5 days ago
parent 6f149b7c11
commit 11975842f6

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db import transaction
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from rest_framework import status from rest_framework import status
from rest_framework.decorators import action from rest_framework.decorators import action
@ -82,7 +83,8 @@ class GatheredAccountViewSet(OrgBulkModelViewSet):
'is_sync_account': False, 'is_sync_account': False,
'name': 'Adhoc gather accounts: {}'.format(asset_id), 'name': 'Adhoc gather accounts: {}'.format(asset_id),
} }
execution.save() with transaction.atomic():
execution.save()
execution.start() execution.start()
accounts = self.model.objects.filter(asset=asset) accounts = self.model.objects.filter(asset=asset)
return self.get_paginated_response_from_queryset(accounts) return self.get_paginated_response_from_queryset(accounts)

@ -13,11 +13,3 @@ class AccountBasePlaybookManager(BasePlaybookManager):
@property @property
def platform_automation_methods(self): def platform_automation_methods(self):
return platform_automation_methods return platform_automation_methods
def gen_report(self):
context = {
'execution': self.execution,
'summary': self.execution.summary,
'result': self.execution.result
}
return render_to_string(self.template_path, context)

@ -7,11 +7,15 @@ from django.utils import timezone
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from xlsxwriter import Workbook from xlsxwriter import Workbook
from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy, ChangeSecretRecordStatusChoice from accounts.const import (
AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy, ChangeSecretRecordStatusChoice
)
from accounts.models import ChangeSecretRecord, BaseAccountQuerySet from accounts.models import ChangeSecretRecord, BaseAccountQuerySet
from accounts.notifications import ChangeSecretExecutionTaskMsg, ChangeSecretFailedMsg from accounts.notifications import ChangeSecretExecutionTaskMsg, ChangeSecretFailedMsg
from accounts.serializers import ChangeSecretRecordBackUpSerializer from accounts.serializers import ChangeSecretRecordBackUpSerializer
from assets.const import HostTypes from assets.const import HostTypes
from common.db.utils import safe_db_connection
from common.decorators import bulk_create_decorator
from common.utils import get_logger from common.utils import get_logger
from common.utils.file import encrypt_and_compress_zip_file from common.utils.file import encrypt_and_compress_zip_file
from common.utils.timezone import local_now_filename from common.utils.timezone import local_now_filename
@ -26,7 +30,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.record_map = self.execution.snapshot.get('record_map', {}) self.record_map = self.execution.snapshot.get('record_map', {}) # 这个是某个失败的记录重试
self.secret_type = self.execution.snapshot.get('secret_type') self.secret_type = self.execution.snapshot.get('secret_type')
self.secret_strategy = self.execution.snapshot.get( self.secret_strategy = self.execution.snapshot.get(
'secret_strategy', SecretStrategy.custom 'secret_strategy', SecretStrategy.custom
@ -36,6 +40,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
) )
self.account_ids = self.execution.snapshot['accounts'] self.account_ids = self.execution.snapshot['accounts']
self.name_recorder_mapper = {} # 做个映射,方便后面处理 self.name_recorder_mapper = {} # 做个映射,方便后面处理
self.pending_add_records = []
@classmethod @classmethod
def method_type(cls): def method_type(cls):
@ -52,18 +57,6 @@ class ChangeSecretManager(AccountBasePlaybookManager):
kwargs['regexp'] = '.*{}$'.format(secret.split()[2].strip()) kwargs['regexp'] = '.*{}$'.format(secret.split()[2].strip())
return kwargs return kwargs
def secret_generator(self, secret_type):
return SecretGenerator(
self.secret_strategy, secret_type,
self.execution.snapshot.get('password_rules')
)
def get_secret(self, secret_type):
if self.secret_strategy == SecretStrategy.custom:
return self.execution.snapshot['secret']
else:
return self.secret_generator(secret_type).get_secret()
def get_accounts(self, privilege_account) -> BaseAccountQuerySet | None: def get_accounts(self, privilege_account) -> BaseAccountQuerySet | None:
if not privilege_account: if not privilege_account:
print('Not privilege account') print('Not privilege account')
@ -81,10 +74,65 @@ class ChangeSecretManager(AccountBasePlaybookManager):
) )
return accounts return accounts
def host_callback( def gen_new_secret(self, account, path_dir):
self, host, asset=None, account=None, private_key_path = None
automation=None, path_dir=None, **kwargs if self.secret_type is None:
): new_secret = account.secret
return new_secret, private_key_path
if self.secret_strategy == SecretStrategy.custom:
new_secret = self.execution.snapshot['secret']
else:
generator = SecretGenerator(
self.secret_strategy, self.secret_type,
self.execution.snapshot.get('password_rules')
)
new_secret = generator.get_secret()
if account.secret_type == SecretType.SSH_KEY:
private_key_path = self.generate_private_key_path(new_secret, path_dir)
new_secret = self.generate_public_key(new_secret)
return new_secret, private_key_path
def get_or_create_record(self, asset, account, new_secret, name):
asset_account_id = f'{asset.id}-{account.id}'
if asset_account_id in self.record_map:
record_id = self.record_map[asset_account_id]
recorder = ChangeSecretRecord.objects.filter(id=record_id).first()
else:
recorder = self.create_record(asset, account, new_secret)
if recorder:
self.name_recorder_mapper[name] = recorder
@bulk_create_decorator(ChangeSecretRecord)
def create_record(self, asset, account, new_secret):
recorder = ChangeSecretRecord(
asset=asset, account=account, execution=self.execution,
old_secret=account.secret, new_secret=new_secret,
comment=f'{account.username}@{asset.address}'
)
return recorder
def gen_change_secret_inventory(self, host, account, new_secret, private_key_path, asset):
h = deepcopy(host)
secret_type = account.secret_type
h['name'] += '(' + account.username + ')'
h['ssh_params'].update(self.get_ssh_params(account, new_secret, secret_type))
h['account'] = {
'name': account.name,
'username': account.username,
'secret_type': secret_type,
'secret': account.escape_jinja2_syntax(new_secret),
'private_key_path': private_key_path,
'become': account.get_ansible_become_auth(),
}
if asset.platform.type == 'oracle':
h['account']['mode'] = 'sysdba' if account.privileged else None
return h
def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs):
host = super().host_callback( host = super().host_callback(
host, asset=asset, account=account, automation=automation, host, asset=asset, account=account, automation=automation,
path_dir=path_dir, **kwargs path_dir=path_dir, **kwargs
@ -93,72 +141,29 @@ class ChangeSecretManager(AccountBasePlaybookManager):
return host return host
host['check_conn_after_change'] = self.execution.snapshot.get('check_conn_after_change', True) host['check_conn_after_change'] = self.execution.snapshot.get('check_conn_after_change', True)
host['ssh_params'] = {}
accounts = self.get_accounts(account) accounts = self.get_accounts(account)
error_msg = _("No pending accounts found") error_msg = _("! No pending accounts found")
if not accounts: if not accounts:
print(f'{asset}: {error_msg}') print(f'{asset}: {error_msg}')
return [] return []
records = [] if asset.type == HostTypes.WINDOWS:
accounts = accounts.filter(secret_type=SecretType.PASSWORD)
inventory_hosts = [] inventory_hosts = []
if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY: if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY:
print(f'Windows {asset} does not support ssh key push') print(f'! Windows {asset} does not support ssh key push')
return inventory_hosts return inventory_hosts
if asset.type == HostTypes.WINDOWS:
accounts = accounts.filter(secret_type=SecretType.PASSWORD)
host['ssh_params'] = {}
for account in accounts: for account in accounts:
h = deepcopy(host) new_secret, private_key_path = self.gen_new_secret(account, path_dir)
secret_type = account.secret_type h = self.gen_change_secret_inventory(host, account, new_secret, private_key_path, asset)
h['name'] += '(' + account.username + ')' self.get_or_create_record(asset, account, new_secret, h['name'])
if self.secret_type is None:
new_secret = account.secret
else:
new_secret = self.get_secret(secret_type)
if new_secret is None:
print(f'new_secret is None, account: {account}')
continue
asset_account_id = f'{asset.id}-{account.id}'
if asset_account_id not in self.record_map:
recorder = ChangeSecretRecord(
asset=asset, account=account, execution=self.execution,
old_secret=account.secret, new_secret=new_secret,
comment=f'{account.username}@{asset.address}'
)
records.append(recorder)
else:
record_id = self.record_map[asset_account_id]
try:
recorder = ChangeSecretRecord.objects.get(id=record_id)
except ChangeSecretRecord.DoesNotExist:
print(f"Record {record_id} not found")
continue
self.name_recorder_mapper[h['name']] = recorder
private_key_path = None
if secret_type == SecretType.SSH_KEY:
private_key_path = self.generate_private_key_path(new_secret, path_dir)
new_secret = self.generate_public_key(new_secret)
h['ssh_params'].update(self.get_ssh_params(account, new_secret, secret_type))
h['account'] = {
'name': account.name,
'username': account.username,
'secret_type': secret_type,
'secret': account.escape_jinja2_syntax(new_secret),
'private_key_path': private_key_path,
'become': account.get_ansible_become_auth(),
}
if asset.platform.type == 'oracle':
h['account']['mode'] = 'sysdba' if account.privileged else None
inventory_hosts.append(h) inventory_hosts.append(h)
ChangeSecretRecord.objects.bulk_create(records)
self.create_record.finish()
return inventory_hosts return inventory_hosts
def on_host_success(self, host, result): def on_host_success(self, host, result):
@ -172,24 +177,13 @@ class ChangeSecretManager(AccountBasePlaybookManager):
if not account: if not account:
print("Account not found, deleted ?") print("Account not found, deleted ?")
return return
account.secret = recorder.new_secret account.secret = recorder.new_secret
account.date_updated = timezone.now() account.date_updated = timezone.now()
max_retries = 3 with safe_db_connection():
retry_count = 0 recorder.save(update_fields=['status', 'date_finished'])
account.save(update_fields=['secret', 'version', 'date_updated'])
while retry_count < max_retries:
try:
recorder.save()
account.save(update_fields=['secret', 'version', 'date_updated'])
break
except Exception as e:
retry_count += 1
if retry_count == max_retries:
self.on_host_error(host, str(e), result)
else:
print(f'retry {retry_count} times for {host} recorder save error: {e}')
time.sleep(1)
def on_host_error(self, host, error, result): def on_host_error(self, host, error, result):
recorder = self.name_recorder_mapper.get(host) recorder = self.name_recorder_mapper.get(host)
@ -222,21 +216,24 @@ class ChangeSecretManager(AccountBasePlaybookManager):
else: else:
failed += 1 failed += 1
total += 1 total += 1
summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total) summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
return summary return summary
def run(self, *args, **kwargs): def print_summary(self):
if self.secret_type and not self.check_secret():
self.execution.status = 'success'
self.execution.date_finished = timezone.now()
self.execution.save()
return
super().run(*args, **kwargs)
recorders = list(self.name_recorder_mapper.values()) recorders = list(self.name_recorder_mapper.values())
summary = self.get_summary(recorders) summary = self.get_summary(recorders)
print(summary, end='') print('\n\n' + '-' * 80)
plan_execution_end = _('Plan execution end')
print('{} {}\n'.format(plan_execution_end, local_now_filename()))
time_cost = _('Time cost')
print('{}: {}s'.format(time_cost, self.duration))
print(summary)
def send_report_if_need(self, *args, **kwargs):
if self.secret_type and not self.check_secret():
return
recorders = list(self.name_recorder_mapper.values())
if self.record_map: if self.record_map:
return return
@ -262,6 +259,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
if not recorders: if not recorders:
return return
summary = self.get_summary(recorders)
self.send_recorder_mail(recipients, recorders, summary) self.send_recorder_mail(recipients, recorders, summary)
def send_recorder_mail(self, recipients, recorders, summary): def send_recorder_mail(self, recipients, recorders, summary):

@ -88,7 +88,7 @@ def check_account_secrets(accounts, assets):
class CheckAccountManager(BaseManager): class CheckAccountManager(BaseManager):
batch_size=100 batch_size = 100
def __init__(self, execution): def __init__(self, execution):
super().__init__(execution) super().__init__(execution)
@ -117,8 +117,14 @@ class CheckAccountManager(BaseManager):
for k, v in result.items(): for k, v in result.items():
self.result[k].extend(v) self.result[k].extend(v)
def get_report_subject(self):
return "Check account report of %s" % self.execution.id
def get_report_template(self):
return 'accounts/check_account_report.html'
def print_summary(self): def print_summary(self):
tmpl = "\n---\nSummary: \nok: %s, weak password: %s, no secret: %s, using time: %ss" % ( tmpl = "\n---\nSummary: \nok: %s, weak password: %s, no secret: %s, using time: %ss" % (
self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.timedelta) self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.duration)
) )
print(tmpl) print(tmpl)

@ -6,6 +6,7 @@ from accounts.const import AutomationTypes
from accounts.models import GatheredAccount, Account, AccountRisk from accounts.models import GatheredAccount, Account, AccountRisk
from assets.models import Asset from assets.models import Asset
from common.const import ConfirmOrIgnore from common.const import ConfirmOrIgnore
from common.decorators import bulk_create_decorator, bulk_update_decorator
from common.utils import get_logger from common.utils import get_logger
from common.utils.strings import get_text_diff from common.utils.strings import get_text_diff
from orgs.utils import tmp_to_org from orgs.utils import tmp_to_org
@ -62,35 +63,18 @@ class AnalyseAccountRisk:
if not diff: if not diff:
return return
risks = []
for k, v in diff.items(): for k, v in diff.items():
self.pending_add_risks.append(dict( risks.append(dict(
asset=ori_account.asset, username=ori_account.username, asset=ori_account.asset, username=ori_account.username,
risk=k+'_changed', detail={'diff': v} risk=k+'_changed', detail={'diff': v}
)) ))
self.save_or_update_risks(risks)
def perform_save_risks(self, risks):
# 提前取出来,避免每次都查数据库
assets = {r['asset'] for r in risks}
assets_risks = AccountRisk.objects.filter(asset__in=assets)
assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks}
for d in risks:
detail = d.pop('detail', {})
detail['datetime'] = self.now.isoformat()
key = f"{d['asset'].id}_{d['username']}_{d['risk']}"
found = assets_risks.get(key)
if not found:
r = AccountRisk(**d, details=[detail])
r.save()
continue
found.details.append(detail)
found.save(update_fields=['details'])
def _analyse_datetime_changed(self, ori_account, d, asset, username): def _analyse_datetime_changed(self, ori_account, d, asset, username):
basic = {'asset': asset, 'username': username} basic = {'asset': asset, 'username': username}
risks = []
for item in self.datetime_check_items: for item in self.datetime_check_items:
field = item['field'] field = item['field']
risk = item['risk'] risk = item['risk']
@ -105,41 +89,61 @@ class AnalyseAccountRisk:
continue continue
if date and date < timezone.now() - delta: if date and date < timezone.now() - delta:
self.pending_add_risks.append( risks.append(
dict(**basic, risk=risk, detail={'date': date.isoformat()}) dict(**basic, risk=risk, detail={'date': date.isoformat()})
) )
def batch_analyse_risk(self, asset, ori_account, d, batch_size=20): self.save_or_update_risks(risks)
if not self.check_risk:
return def save_or_update_risks(self, risks):
# 提前取出来,避免每次都查数据库
assets = {r['asset'] for r in risks}
assets_risks = AccountRisk.objects.filter(asset__in=assets)
assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks}
if asset is None: for d in risks:
if self.pending_add_risks: detail = d.pop('detail', {})
self.perform_save_risks(self.pending_add_risks) detail['datetime'] = self.now.isoformat()
self.pending_add_risks = [] key = f"{d['asset'].id}_{d['username']}_{d['risk']}"
found = assets_risks.get(key)
if not found:
self._create_risk(dict(**d, details=[detail]))
continue
found.details.append(detail)
self._update_risk(found)
@bulk_create_decorator(AccountRisk)
def _create_risk(self, data):
return AccountRisk(**data)
@bulk_update_decorator(AccountRisk, update_fields=['details'])
def _update_risk(self, account):
return account
def finish(self):
self._create_risk.finish()
self._update_risk.finish()
def analyse_risk(self, asset, ori_account, d):
if not self.check_risk:
return return
basic = {'asset': asset, 'username': d['username']} basic = {'asset': asset, 'username': d['username']}
if ori_account: if ori_account:
self._analyse_item_changed(ori_account, d) self._analyse_item_changed(ori_account, d)
else: else:
self.pending_add_risks.append( self._create_risk(dict(**basic, risk='new_account'))
dict(**basic, risk='ghost')
)
self._analyse_datetime_changed(ori_account, d, asset, d['username']) self._analyse_datetime_changed(ori_account, d, asset, d['username'])
if len(self.pending_add_risks) > batch_size:
self.batch_analyse_risk(None, None, {})
class GatherAccountsManager(AccountBasePlaybookManager): class GatherAccountsManager(AccountBasePlaybookManager):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.host_asset_mapper = {} self.host_asset_mapper = {}
self.asset_account_info = {} self.asset_account_info = {}
self.pending_add_accounts = []
self.pending_update_accounts = []
self.asset_usernames_mapper = defaultdict(set) self.asset_usernames_mapper = defaultdict(set)
self.ori_asset_usernames = defaultdict(set) self.ori_asset_usernames = defaultdict(set)
self.ori_gathered_usernames = defaultdict(set) self.ori_gathered_usernames = defaultdict(set)
@ -204,7 +208,7 @@ class GatherAccountsManager(AccountBasePlaybookManager):
ga_accounts = GatheredAccount.objects.filter(asset__in=assets) ga_accounts = GatheredAccount.objects.filter(asset__in=assets)
for account in ga_accounts: for account in ga_accounts:
self.ori_gathered_usernames[account.asset].add(account.username) self.ori_gathered_usernames[account.asset].add(account.username)
key = '{}_{}'.format(account.asset.id, account.username) key = '{}_{}'.format(account.asset_id, account.username)
self.ori_gathered_accounts_mapper[key] = account self.ori_gathered_accounts_mapper[key] = account
def update_gather_accounts_status(self, asset): def update_gather_accounts_status(self, asset):
@ -258,38 +262,25 @@ class GatherAccountsManager(AccountBasePlaybookManager):
# 资产上没有的,标识为为存在 # 资产上没有的,标识为为存在
queryset.exclude(username__in=ori_users).filter(present=False).update(present=True) queryset.exclude(username__in=ori_users).filter(present=False).update(present=True)
def batch_create_gathered_account(self, d, batch_size=20): @bulk_create_decorator(GatheredAccount)
if d is None: def create_gathered_account(self, d):
if self.pending_add_accounts:
GatheredAccount.objects.bulk_create(self.pending_add_accounts, ignore_conflicts=True)
self.pending_add_accounts = []
return
gathered_account = GatheredAccount() gathered_account = GatheredAccount()
for k, v in d.items(): for k, v in d.items():
setattr(gathered_account, k, v) setattr(gathered_account, k, v)
self.pending_add_accounts.append(gathered_account) return gathered_account
if len(self.pending_add_accounts) > batch_size:
self.batch_create_gathered_account(None)
def batch_update_gathered_account(self, ori_account, d, batch_size=20):
if not ori_account or d is None:
if self.pending_update_accounts:
GatheredAccount.objects.bulk_update(self.pending_update_accounts, [*diff_items])
self.pending_update_accounts = []
return
@bulk_update_decorator(GatheredAccount, update_fields=diff_items)
def update_gathered_account(self, ori_account, d):
diff = get_items_diff(ori_account, d) diff = get_items_diff(ori_account, d)
if diff: if not diff:
for k in diff: return
setattr(ori_account, k, d[k]) for k in diff:
self.pending_update_accounts.append(ori_account) setattr(ori_account, k, d[k])
return ori_account
if len(self.pending_update_accounts) > batch_size:
self.batch_update_gathered_account(None, None)
def do_run(self): def do_run(self, *args, **kwargs):
super().do_run(*args, **kwargs)
self.prefetch_origin_account_usernames()
risk_analyser = AnalyseAccountRisk(self.check_risk) risk_analyser = AnalyseAccountRisk(self.check_risk)
for asset, accounts_data in self.asset_account_info.items(): for asset, accounts_data in self.asset_account_info.items():
@ -300,21 +291,20 @@ class GatherAccountsManager(AccountBasePlaybookManager):
ori_account = self.ori_gathered_accounts_mapper.get('{}_{}'.format(asset.id, username)) ori_account = self.ori_gathered_accounts_mapper.get('{}_{}'.format(asset.id, username))
if not ori_account: if not ori_account:
self.batch_create_gathered_account(d) self.create_gathered_account(d)
else: else:
self.batch_update_gathered_account(ori_account, d) self.update_gathered_account(ori_account, d)
risk_analyser.batch_analyse_risk(asset, ori_account, d) risk_analyser.analyse_risk(asset, ori_account, d)
self.update_gather_accounts_status(asset) self.update_gather_accounts_status(asset)
GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account) GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account)
self.batch_create_gathered_account(None) self.create_gathered_account.finish()
self.batch_update_gathered_account(None, None) self.update_gathered_account.finish()
risk_analyser.batch_analyse_risk(None, None, {}) risk_analyser.finish()
def before_run(self): def send_report_if_need(self):
super().before_run() pass
self.prefetch_origin_account_usernames()
def generate_send_users_and_change_info(self): def generate_send_users_and_change_info(self):
recipients = self.execution.recipients recipients = self.execution.recipients

@ -98,60 +98,63 @@ class BaseManager:
def get_assets_group_by_platform(self): def get_assets_group_by_platform(self):
return self.execution.all_assets_group_by_platform() return self.execution.all_assets_group_by_platform()
def before_run(self): def pre_run(self):
self.execution.date_start = timezone.now() self.execution.date_start = timezone.now()
self.execution.save(update_fields=['date_start']) self.execution.save(update_fields=['date_start'])
def get_report_subject(self):
return f'Automation {self.execution.id} finished'
def send_report_if_need(self):
recipients = self.execution.recipients
if not recipients:
return
report = self.gen_report()
report = transform(report)
subject = self.get_report_subject()
print("Send resport to: {}".format([str(r) for r in recipients]))
emails = [r.email for r in recipients if r.email]
send_mail_async(subject, report, emails, html_message=report)
def update_execution(self): def update_execution(self):
self.duration = int(time.time() - self.time_start) self.duration = int(time.time() - self.time_start)
self.execution.date_finished = timezone.now() self.execution.date_finished = timezone.now()
self.execution.duration = self.duration self.execution.duration = self.duration
self.execution.summary = self.summary self.execution.summary = self.summary
self.execution.result = self.result self.execution.result = self.result
self.execution.status = 'success'
with safe_db_connection(): with safe_db_connection():
self.execution.save(update_fields=['date_finished', 'duration', 'summary', 'result']) self.execution.save()
def print_summary(self): def print_summary(self):
pass pass
def get_template_path(self): def get_report_template(self):
raise NotImplementedError raise NotImplementedError
def gen_report(self): def get_report_subject(self):
template_path = self.get_template_path() return f'Automation {self.execution.id} finished'
context = {
def get_report_context(self):
return {
'execution': self.execution, 'execution': self.execution,
'summary': self.execution.summary, 'summary': self.execution.summary,
'result': self.execution.result 'result': self.execution.result
} }
def send_report_if_need(self):
recipients = self.execution.recipients
if not recipients:
return
report = self.gen_report()
report = transform(report)
subject = self.get_report_subject()
emails = [r.email for r in recipients if r.email]
send_mail_async(subject, report, emails, html_message=report)
def gen_report(self):
template_path = self.get_report_template()
context = self.get_report_context()
data = render_to_string(template_path, context) data = render_to_string(template_path, context)
return data return data
def after_run(self): def post_run(self):
self.update_execution() self.update_execution()
self.print_summary() self.print_summary()
self.send_report_if_need() self.send_report_if_need()
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
self.before_run() self.pre_run()
self.do_run(*args, **kwargs) self.do_run(*args, **kwargs)
self.after_run() self.post_run()
def do_run(self, *args, **kwargs): def do_run(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
@ -374,15 +377,19 @@ class BasePlaybookManager(BaseManager):
if settings.DEBUG_DEV: if settings.DEBUG_DEV:
print('host error: {} -> {}'.format(host, error)) print('host error: {} -> {}'.format(host, error))
def _on_host_success(self, host, result, error, detail): def _on_host_success(self, host, result, hosts):
self.on_host_success(host, result) self.on_host_success(host, result.get("ok", ''))
def _on_host_error(self, host, result, error, detail): def _on_host_error(self, host, result, hosts):
error = hosts.get(host, '')
detail = result.get('failures', '') or result.get('dark', '')
self.on_host_error(host, error, detail) self.on_host_error(host, error, detail)
def on_runner_success(self, runner, cb): def on_runner_success(self, runner, cb):
summary = cb.summary summary = cb.summary
for state, hosts in summary.items(): for state, hosts in summary.items():
# 错误行为为host 是 dict ok 时是 list
if state == 'ok': if state == 'ok':
handler = self._on_host_success handler = self._on_host_success
elif state == 'skipped': elif state == 'skipped':
@ -392,9 +399,7 @@ class BasePlaybookManager(BaseManager):
for host in hosts: for host in hosts:
result = cb.host_results.get(host) result = cb.host_results.get(host)
error = hosts.get(host, '') handler(host, result, hosts)
detail = result.get('failures', '') or result.get('dark', '')
handler(host, result, error, detail)
def on_runner_failed(self, runner, e): def on_runner_failed(self, runner, e):
print("Runner failed: {} {}".format(e, self)) print("Runner failed: {} {}".format(e, self))

@ -1,6 +1,6 @@
from contextlib import contextmanager from contextlib import contextmanager
from django.db import connections, transaction from django.db import connections, transaction, connection
from django.utils.encoding import force_str from django.utils.encoding import force_str
from common.utils import get_logger, signer, crypto from common.utils import get_logger, signer, crypto
@ -13,7 +13,7 @@ def get_object_if_need(model, pk):
try: try:
return model.objects.get(id=pk) return model.objects.get(id=pk)
except model.DoesNotExist as e: except model.DoesNotExist as e:
logger.error(f'DoesNotExist: <{model.__name__}:{pk}> not exist') logger.error(f"DoesNotExist: <{model.__name__}:{pk}> not exist")
raise e raise e
return pk return pk
@ -26,8 +26,8 @@ def get_objects_if_need(model, pks):
if len(objs) != len(pks): if len(objs) != len(pks):
pks = set(pks) pks = set(pks)
exists_pks = {o.id for o in objs} exists_pks = {o.id for o in objs}
not_found_pks = ','.join(pks - exists_pks) not_found_pks = ",".join(pks - exists_pks)
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>') logger.error(f"DoesNotExist: <{model.__name__}: {not_found_pks}>")
return objs return objs
return pks return pks
@ -41,7 +41,7 @@ def get_objects(model, pks):
pks = set(pks) pks = set(pks)
exists_pks = {o.id for o in objs} exists_pks = {o.id for o in objs}
not_found_pks = pks - exists_pks not_found_pks = pks - exists_pks
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>') logger.error(f"DoesNotExist: <{model.__name__}: {not_found_pks}>")
return objs return objs
@ -53,13 +53,17 @@ def close_old_connections():
@contextmanager @contextmanager
def safe_db_connection(): def safe_db_connection():
close_old_connections() try:
yield close_old_connections() # 确保旧连接关闭
close_old_connections() if connection.connection: # 如果连接已关闭,重新连接
connection.connect()
yield
finally:
close_old_connections() # 确保最终关闭连接
@contextmanager @contextmanager
def open_db_connection(alias='default'): def open_db_connection(alias="default"):
connection = transaction.get_connection(alias) connection = transaction.get_connection(alias)
try: try:
connection.connect() connection.connect()

@ -11,8 +11,8 @@ from functools import wraps
from django.db import transaction from django.db import transaction
from .utils import logger
from .db.utils import open_db_connection from .db.utils import open_db_connection
from .utils import logger
def on_transaction_commit(func): def on_transaction_commit(func):
@ -294,3 +294,104 @@ def cached_method(ttl=20):
return wrapper return wrapper
return decorator return decorator
def bulk_create_decorator(instance_model, batch_size=50):
"""
装饰器用于将实例批量保存并提供 `commit` 方法提交剩余的实例
:param instance_model: Django模型类用于调用 bulk_create 方法
:param batch_size: 批量保存的阈值默认50
"""
def decorator(func):
cache = [] # 缓存实例的列表
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal cache
# 调用被装饰的函数,生成一个实例
instance = func(*args, **kwargs)
if instance is None:
return None
# 添加实例到缓存
cache.append(instance)
print(f"Instance added to cache. Cache size: {len(cache)}")
# 如果缓存大小达到批量保存阈值,执行保存
if len(cache) >= batch_size:
print(f"Batch size reached. Saving {len(cache)} instances...")
instance_model.objects.bulk_create(cache)
cache.clear()
return instance
# 提交剩余实例的方法
def commit():
nonlocal cache
if cache:
print(f"Committing remaining {len(cache)} instances...")
instance_model.objects.bulk_create(cache)
cache.clear()
wrapper.finish = commit
return wrapper
return decorator
def bulk_update_decorator(instance_model, batch_size=50, update_fields=None):
"""
装饰器用于批量更新实例并提供 `commit` 方法提交剩余的更新
:param instance_model: Django模型类用于调用 bulk_update 方法
:param batch_size: 批量更新的阈值默认50
:param update_fields: 指定要更新的字段列表默认为 None表示更新所有字段
"""
def decorator(func):
cache = [] # 缓存待更新实例的列表
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal cache
# 调用被装饰的函数,获取一个需要更新的实例
instance = func(*args, **kwargs)
if instance is None:
return None
# 添加实例到缓存
cache.append(instance)
print(f"Instance added to update cache. Cache size: {len(cache)}")
# 如果缓存大小达到批量更新阈值,执行更新
if len(cache) >= batch_size:
print(f"Batch size reached. Updating {len(cache)} instances...")
instance_model.objects.bulk_update(cache, update_fields)
cache.clear()
return instance
# 提交剩余更新的方法
def commit():
nonlocal cache
if cache:
print(f"Committing remaining {len(cache)} instances..., {update_fields}")
# with transaction.atomic():
# for c in cache:
# o = instance_model.objects.get(id=str(c.id))
# print("Origin: ", o.id, o.sudoers)
# o.sudoers = c.sudoers
# o.save()
# print("New: ", c.id, c.sudoers)
instance_model.objects.bulk_update(cache, update_fields)
# print("Committing remaining instances... done, ", cache[0].sudoers, cache[0].id, instance_model)
# print(instance_model.objects.get(id=str(cache[0].id)).sudoers)
cache.clear()
# 将 commit 方法绑定到装饰后的函数
wrapper.finish = commit
return wrapper
return decorator

Loading…
Cancel
Save