mirror of https://github.com/jumpserver/jumpserver
perf: update bulk update create
parent
6f149b7c11
commit
11975842f6
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.db import transaction
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import action
|
||||
|
@ -82,7 +83,8 @@ class GatheredAccountViewSet(OrgBulkModelViewSet):
|
|||
'is_sync_account': False,
|
||||
'name': 'Adhoc gather accounts: {}'.format(asset_id),
|
||||
}
|
||||
execution.save()
|
||||
with transaction.atomic():
|
||||
execution.save()
|
||||
execution.start()
|
||||
accounts = self.model.objects.filter(asset=asset)
|
||||
return self.get_paginated_response_from_queryset(accounts)
|
||||
|
|
|
@ -13,11 +13,3 @@ class AccountBasePlaybookManager(BasePlaybookManager):
|
|||
@property
|
||||
def platform_automation_methods(self):
|
||||
return platform_automation_methods
|
||||
|
||||
def gen_report(self):
|
||||
context = {
|
||||
'execution': self.execution,
|
||||
'summary': self.execution.summary,
|
||||
'result': self.execution.result
|
||||
}
|
||||
return render_to_string(self.template_path, context)
|
||||
|
|
|
@ -7,11 +7,15 @@ from django.utils import timezone
|
|||
from django.utils.translation import gettext_lazy as _
|
||||
from xlsxwriter import Workbook
|
||||
|
||||
from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy, ChangeSecretRecordStatusChoice
|
||||
from accounts.const import (
|
||||
AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy, ChangeSecretRecordStatusChoice
|
||||
)
|
||||
from accounts.models import ChangeSecretRecord, BaseAccountQuerySet
|
||||
from accounts.notifications import ChangeSecretExecutionTaskMsg, ChangeSecretFailedMsg
|
||||
from accounts.serializers import ChangeSecretRecordBackUpSerializer
|
||||
from assets.const import HostTypes
|
||||
from common.db.utils import safe_db_connection
|
||||
from common.decorators import bulk_create_decorator
|
||||
from common.utils import get_logger
|
||||
from common.utils.file import encrypt_and_compress_zip_file
|
||||
from common.utils.timezone import local_now_filename
|
||||
|
@ -26,7 +30,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.record_map = self.execution.snapshot.get('record_map', {})
|
||||
self.record_map = self.execution.snapshot.get('record_map', {}) # 这个是某个失败的记录重试
|
||||
self.secret_type = self.execution.snapshot.get('secret_type')
|
||||
self.secret_strategy = self.execution.snapshot.get(
|
||||
'secret_strategy', SecretStrategy.custom
|
||||
|
@ -36,6 +40,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
)
|
||||
self.account_ids = self.execution.snapshot['accounts']
|
||||
self.name_recorder_mapper = {} # 做个映射,方便后面处理
|
||||
self.pending_add_records = []
|
||||
|
||||
@classmethod
|
||||
def method_type(cls):
|
||||
|
@ -52,18 +57,6 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
kwargs['regexp'] = '.*{}$'.format(secret.split()[2].strip())
|
||||
return kwargs
|
||||
|
||||
def secret_generator(self, secret_type):
|
||||
return SecretGenerator(
|
||||
self.secret_strategy, secret_type,
|
||||
self.execution.snapshot.get('password_rules')
|
||||
)
|
||||
|
||||
def get_secret(self, secret_type):
|
||||
if self.secret_strategy == SecretStrategy.custom:
|
||||
return self.execution.snapshot['secret']
|
||||
else:
|
||||
return self.secret_generator(secret_type).get_secret()
|
||||
|
||||
def get_accounts(self, privilege_account) -> BaseAccountQuerySet | None:
|
||||
if not privilege_account:
|
||||
print('Not privilege account')
|
||||
|
@ -81,10 +74,65 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
)
|
||||
return accounts
|
||||
|
||||
def host_callback(
|
||||
self, host, asset=None, account=None,
|
||||
automation=None, path_dir=None, **kwargs
|
||||
):
|
||||
def gen_new_secret(self, account, path_dir):
|
||||
private_key_path = None
|
||||
if self.secret_type is None:
|
||||
new_secret = account.secret
|
||||
return new_secret, private_key_path
|
||||
|
||||
if self.secret_strategy == SecretStrategy.custom:
|
||||
new_secret = self.execution.snapshot['secret']
|
||||
else:
|
||||
generator = SecretGenerator(
|
||||
self.secret_strategy, self.secret_type,
|
||||
self.execution.snapshot.get('password_rules')
|
||||
)
|
||||
new_secret = generator.get_secret()
|
||||
|
||||
if account.secret_type == SecretType.SSH_KEY:
|
||||
private_key_path = self.generate_private_key_path(new_secret, path_dir)
|
||||
new_secret = self.generate_public_key(new_secret)
|
||||
return new_secret, private_key_path
|
||||
|
||||
def get_or_create_record(self, asset, account, new_secret, name):
|
||||
asset_account_id = f'{asset.id}-{account.id}'
|
||||
|
||||
if asset_account_id in self.record_map:
|
||||
record_id = self.record_map[asset_account_id]
|
||||
recorder = ChangeSecretRecord.objects.filter(id=record_id).first()
|
||||
else:
|
||||
recorder = self.create_record(asset, account, new_secret)
|
||||
|
||||
if recorder:
|
||||
self.name_recorder_mapper[name] = recorder
|
||||
|
||||
@bulk_create_decorator(ChangeSecretRecord)
|
||||
def create_record(self, asset, account, new_secret):
|
||||
recorder = ChangeSecretRecord(
|
||||
asset=asset, account=account, execution=self.execution,
|
||||
old_secret=account.secret, new_secret=new_secret,
|
||||
comment=f'{account.username}@{asset.address}'
|
||||
)
|
||||
return recorder
|
||||
|
||||
def gen_change_secret_inventory(self, host, account, new_secret, private_key_path, asset):
|
||||
h = deepcopy(host)
|
||||
secret_type = account.secret_type
|
||||
h['name'] += '(' + account.username + ')'
|
||||
h['ssh_params'].update(self.get_ssh_params(account, new_secret, secret_type))
|
||||
h['account'] = {
|
||||
'name': account.name,
|
||||
'username': account.username,
|
||||
'secret_type': secret_type,
|
||||
'secret': account.escape_jinja2_syntax(new_secret),
|
||||
'private_key_path': private_key_path,
|
||||
'become': account.get_ansible_become_auth(),
|
||||
}
|
||||
if asset.platform.type == 'oracle':
|
||||
h['account']['mode'] = 'sysdba' if account.privileged else None
|
||||
return h
|
||||
|
||||
def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs):
|
||||
host = super().host_callback(
|
||||
host, asset=asset, account=account, automation=automation,
|
||||
path_dir=path_dir, **kwargs
|
||||
|
@ -93,72 +141,29 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
return host
|
||||
|
||||
host['check_conn_after_change'] = self.execution.snapshot.get('check_conn_after_change', True)
|
||||
host['ssh_params'] = {}
|
||||
|
||||
accounts = self.get_accounts(account)
|
||||
error_msg = _("No pending accounts found")
|
||||
error_msg = _("! No pending accounts found")
|
||||
if not accounts:
|
||||
print(f'{asset}: {error_msg}')
|
||||
return []
|
||||
|
||||
records = []
|
||||
inventory_hosts = []
|
||||
if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY:
|
||||
print(f'Windows {asset} does not support ssh key push')
|
||||
return inventory_hosts
|
||||
|
||||
if asset.type == HostTypes.WINDOWS:
|
||||
accounts = accounts.filter(secret_type=SecretType.PASSWORD)
|
||||
|
||||
host['ssh_params'] = {}
|
||||
inventory_hosts = []
|
||||
if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY:
|
||||
print(f'! Windows {asset} does not support ssh key push')
|
||||
return inventory_hosts
|
||||
|
||||
for account in accounts:
|
||||
h = deepcopy(host)
|
||||
secret_type = account.secret_type
|
||||
h['name'] += '(' + account.username + ')'
|
||||
if self.secret_type is None:
|
||||
new_secret = account.secret
|
||||
else:
|
||||
new_secret = self.get_secret(secret_type)
|
||||
|
||||
if new_secret is None:
|
||||
print(f'new_secret is None, account: {account}')
|
||||
continue
|
||||
|
||||
asset_account_id = f'{asset.id}-{account.id}'
|
||||
if asset_account_id not in self.record_map:
|
||||
recorder = ChangeSecretRecord(
|
||||
asset=asset, account=account, execution=self.execution,
|
||||
old_secret=account.secret, new_secret=new_secret,
|
||||
comment=f'{account.username}@{asset.address}'
|
||||
)
|
||||
records.append(recorder)
|
||||
else:
|
||||
record_id = self.record_map[asset_account_id]
|
||||
try:
|
||||
recorder = ChangeSecretRecord.objects.get(id=record_id)
|
||||
except ChangeSecretRecord.DoesNotExist:
|
||||
print(f"Record {record_id} not found")
|
||||
continue
|
||||
|
||||
self.name_recorder_mapper[h['name']] = recorder
|
||||
|
||||
private_key_path = None
|
||||
if secret_type == SecretType.SSH_KEY:
|
||||
private_key_path = self.generate_private_key_path(new_secret, path_dir)
|
||||
new_secret = self.generate_public_key(new_secret)
|
||||
|
||||
h['ssh_params'].update(self.get_ssh_params(account, new_secret, secret_type))
|
||||
h['account'] = {
|
||||
'name': account.name,
|
||||
'username': account.username,
|
||||
'secret_type': secret_type,
|
||||
'secret': account.escape_jinja2_syntax(new_secret),
|
||||
'private_key_path': private_key_path,
|
||||
'become': account.get_ansible_become_auth(),
|
||||
}
|
||||
if asset.platform.type == 'oracle':
|
||||
h['account']['mode'] = 'sysdba' if account.privileged else None
|
||||
new_secret, private_key_path = self.gen_new_secret(account, path_dir)
|
||||
h = self.gen_change_secret_inventory(host, account, new_secret, private_key_path, asset)
|
||||
self.get_or_create_record(asset, account, new_secret, h['name'])
|
||||
inventory_hosts.append(h)
|
||||
ChangeSecretRecord.objects.bulk_create(records)
|
||||
|
||||
self.create_record.finish()
|
||||
return inventory_hosts
|
||||
|
||||
def on_host_success(self, host, result):
|
||||
|
@ -172,24 +177,13 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
if not account:
|
||||
print("Account not found, deleted ?")
|
||||
return
|
||||
|
||||
account.secret = recorder.new_secret
|
||||
account.date_updated = timezone.now()
|
||||
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
recorder.save()
|
||||
account.save(update_fields=['secret', 'version', 'date_updated'])
|
||||
break
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count == max_retries:
|
||||
self.on_host_error(host, str(e), result)
|
||||
else:
|
||||
print(f'retry {retry_count} times for {host} recorder save error: {e}')
|
||||
time.sleep(1)
|
||||
with safe_db_connection():
|
||||
recorder.save(update_fields=['status', 'date_finished'])
|
||||
account.save(update_fields=['secret', 'version', 'date_updated'])
|
||||
|
||||
def on_host_error(self, host, error, result):
|
||||
recorder = self.name_recorder_mapper.get(host)
|
||||
|
@ -222,21 +216,24 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
else:
|
||||
failed += 1
|
||||
total += 1
|
||||
|
||||
summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
|
||||
return summary
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
if self.secret_type and not self.check_secret():
|
||||
self.execution.status = 'success'
|
||||
self.execution.date_finished = timezone.now()
|
||||
self.execution.save()
|
||||
return
|
||||
super().run(*args, **kwargs)
|
||||
def print_summary(self):
|
||||
recorders = list(self.name_recorder_mapper.values())
|
||||
summary = self.get_summary(recorders)
|
||||
print(summary, end='')
|
||||
print('\n\n' + '-' * 80)
|
||||
plan_execution_end = _('Plan execution end')
|
||||
print('{} {}\n'.format(plan_execution_end, local_now_filename()))
|
||||
time_cost = _('Time cost')
|
||||
print('{}: {}s'.format(time_cost, self.duration))
|
||||
print(summary)
|
||||
|
||||
def send_report_if_need(self, *args, **kwargs):
|
||||
if self.secret_type and not self.check_secret():
|
||||
return
|
||||
|
||||
recorders = list(self.name_recorder_mapper.values())
|
||||
if self.record_map:
|
||||
return
|
||||
|
||||
|
@ -262,6 +259,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
if not recorders:
|
||||
return
|
||||
|
||||
summary = self.get_summary(recorders)
|
||||
self.send_recorder_mail(recipients, recorders, summary)
|
||||
|
||||
def send_recorder_mail(self, recipients, recorders, summary):
|
||||
|
|
|
@ -88,7 +88,7 @@ def check_account_secrets(accounts, assets):
|
|||
|
||||
|
||||
class CheckAccountManager(BaseManager):
|
||||
batch_size=100
|
||||
batch_size = 100
|
||||
|
||||
def __init__(self, execution):
|
||||
super().__init__(execution)
|
||||
|
@ -117,8 +117,14 @@ class CheckAccountManager(BaseManager):
|
|||
for k, v in result.items():
|
||||
self.result[k].extend(v)
|
||||
|
||||
def get_report_subject(self):
|
||||
return "Check account report of %s" % self.execution.id
|
||||
|
||||
def get_report_template(self):
|
||||
return 'accounts/check_account_report.html'
|
||||
|
||||
def print_summary(self):
|
||||
tmpl = "\n---\nSummary: \nok: %s, weak password: %s, no secret: %s, using time: %ss" % (
|
||||
self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.timedelta)
|
||||
self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.duration)
|
||||
)
|
||||
print(tmpl)
|
||||
|
|
|
@ -6,6 +6,7 @@ from accounts.const import AutomationTypes
|
|||
from accounts.models import GatheredAccount, Account, AccountRisk
|
||||
from assets.models import Asset
|
||||
from common.const import ConfirmOrIgnore
|
||||
from common.decorators import bulk_create_decorator, bulk_update_decorator
|
||||
from common.utils import get_logger
|
||||
from common.utils.strings import get_text_diff
|
||||
from orgs.utils import tmp_to_org
|
||||
|
@ -62,35 +63,18 @@ class AnalyseAccountRisk:
|
|||
if not diff:
|
||||
return
|
||||
|
||||
risks = []
|
||||
for k, v in diff.items():
|
||||
self.pending_add_risks.append(dict(
|
||||
risks.append(dict(
|
||||
asset=ori_account.asset, username=ori_account.username,
|
||||
risk=k+'_changed', detail={'diff': v}
|
||||
))
|
||||
|
||||
def perform_save_risks(self, risks):
|
||||
# 提前取出来,避免每次都查数据库
|
||||
assets = {r['asset'] for r in risks}
|
||||
assets_risks = AccountRisk.objects.filter(asset__in=assets)
|
||||
assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks}
|
||||
|
||||
for d in risks:
|
||||
detail = d.pop('detail', {})
|
||||
detail['datetime'] = self.now.isoformat()
|
||||
key = f"{d['asset'].id}_{d['username']}_{d['risk']}"
|
||||
found = assets_risks.get(key)
|
||||
|
||||
if not found:
|
||||
r = AccountRisk(**d, details=[detail])
|
||||
r.save()
|
||||
continue
|
||||
|
||||
found.details.append(detail)
|
||||
found.save(update_fields=['details'])
|
||||
self.save_or_update_risks(risks)
|
||||
|
||||
def _analyse_datetime_changed(self, ori_account, d, asset, username):
|
||||
basic = {'asset': asset, 'username': username}
|
||||
|
||||
risks = []
|
||||
for item in self.datetime_check_items:
|
||||
field = item['field']
|
||||
risk = item['risk']
|
||||
|
@ -105,41 +89,61 @@ class AnalyseAccountRisk:
|
|||
continue
|
||||
|
||||
if date and date < timezone.now() - delta:
|
||||
self.pending_add_risks.append(
|
||||
risks.append(
|
||||
dict(**basic, risk=risk, detail={'date': date.isoformat()})
|
||||
)
|
||||
|
||||
def batch_analyse_risk(self, asset, ori_account, d, batch_size=20):
|
||||
if not self.check_risk:
|
||||
return
|
||||
self.save_or_update_risks(risks)
|
||||
|
||||
if asset is None:
|
||||
if self.pending_add_risks:
|
||||
self.perform_save_risks(self.pending_add_risks)
|
||||
self.pending_add_risks = []
|
||||
def save_or_update_risks(self, risks):
|
||||
# 提前取出来,避免每次都查数据库
|
||||
assets = {r['asset'] for r in risks}
|
||||
assets_risks = AccountRisk.objects.filter(asset__in=assets)
|
||||
assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks}
|
||||
|
||||
for d in risks:
|
||||
detail = d.pop('detail', {})
|
||||
detail['datetime'] = self.now.isoformat()
|
||||
key = f"{d['asset'].id}_{d['username']}_{d['risk']}"
|
||||
found = assets_risks.get(key)
|
||||
|
||||
if not found:
|
||||
self._create_risk(dict(**d, details=[detail]))
|
||||
continue
|
||||
|
||||
found.details.append(detail)
|
||||
self._update_risk(found)
|
||||
|
||||
@bulk_create_decorator(AccountRisk)
|
||||
def _create_risk(self, data):
|
||||
return AccountRisk(**data)
|
||||
|
||||
@bulk_update_decorator(AccountRisk, update_fields=['details'])
|
||||
def _update_risk(self, account):
|
||||
return account
|
||||
|
||||
def finish(self):
|
||||
self._create_risk.finish()
|
||||
self._update_risk.finish()
|
||||
|
||||
def analyse_risk(self, asset, ori_account, d):
|
||||
if not self.check_risk:
|
||||
return
|
||||
|
||||
basic = {'asset': asset, 'username': d['username']}
|
||||
if ori_account:
|
||||
self._analyse_item_changed(ori_account, d)
|
||||
else:
|
||||
self.pending_add_risks.append(
|
||||
dict(**basic, risk='ghost')
|
||||
)
|
||||
self._create_risk(dict(**basic, risk='new_account'))
|
||||
|
||||
self._analyse_datetime_changed(ori_account, d, asset, d['username'])
|
||||
|
||||
if len(self.pending_add_risks) > batch_size:
|
||||
self.batch_analyse_risk(None, None, {})
|
||||
|
||||
|
||||
class GatherAccountsManager(AccountBasePlaybookManager):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.host_asset_mapper = {}
|
||||
self.asset_account_info = {}
|
||||
self.pending_add_accounts = []
|
||||
self.pending_update_accounts = []
|
||||
self.asset_usernames_mapper = defaultdict(set)
|
||||
self.ori_asset_usernames = defaultdict(set)
|
||||
self.ori_gathered_usernames = defaultdict(set)
|
||||
|
@ -204,7 +208,7 @@ class GatherAccountsManager(AccountBasePlaybookManager):
|
|||
ga_accounts = GatheredAccount.objects.filter(asset__in=assets)
|
||||
for account in ga_accounts:
|
||||
self.ori_gathered_usernames[account.asset].add(account.username)
|
||||
key = '{}_{}'.format(account.asset.id, account.username)
|
||||
key = '{}_{}'.format(account.asset_id, account.username)
|
||||
self.ori_gathered_accounts_mapper[key] = account
|
||||
|
||||
def update_gather_accounts_status(self, asset):
|
||||
|
@ -258,38 +262,25 @@ class GatherAccountsManager(AccountBasePlaybookManager):
|
|||
# 资产上没有的,标识为为存在
|
||||
queryset.exclude(username__in=ori_users).filter(present=False).update(present=True)
|
||||
|
||||
def batch_create_gathered_account(self, d, batch_size=20):
|
||||
if d is None:
|
||||
if self.pending_add_accounts:
|
||||
GatheredAccount.objects.bulk_create(self.pending_add_accounts, ignore_conflicts=True)
|
||||
self.pending_add_accounts = []
|
||||
return
|
||||
|
||||
@bulk_create_decorator(GatheredAccount)
|
||||
def create_gathered_account(self, d):
|
||||
gathered_account = GatheredAccount()
|
||||
for k, v in d.items():
|
||||
setattr(gathered_account, k, v)
|
||||
self.pending_add_accounts.append(gathered_account)
|
||||
|
||||
if len(self.pending_add_accounts) > batch_size:
|
||||
self.batch_create_gathered_account(None)
|
||||
|
||||
def batch_update_gathered_account(self, ori_account, d, batch_size=20):
|
||||
if not ori_account or d is None:
|
||||
if self.pending_update_accounts:
|
||||
GatheredAccount.objects.bulk_update(self.pending_update_accounts, [*diff_items])
|
||||
self.pending_update_accounts = []
|
||||
return
|
||||
return gathered_account
|
||||
|
||||
@bulk_update_decorator(GatheredAccount, update_fields=diff_items)
|
||||
def update_gathered_account(self, ori_account, d):
|
||||
diff = get_items_diff(ori_account, d)
|
||||
if diff:
|
||||
for k in diff:
|
||||
setattr(ori_account, k, d[k])
|
||||
self.pending_update_accounts.append(ori_account)
|
||||
if not diff:
|
||||
return
|
||||
for k in diff:
|
||||
setattr(ori_account, k, d[k])
|
||||
return ori_account
|
||||
|
||||
if len(self.pending_update_accounts) > batch_size:
|
||||
self.batch_update_gathered_account(None, None)
|
||||
|
||||
def do_run(self):
|
||||
def do_run(self, *args, **kwargs):
|
||||
super().do_run(*args, **kwargs)
|
||||
self.prefetch_origin_account_usernames()
|
||||
risk_analyser = AnalyseAccountRisk(self.check_risk)
|
||||
|
||||
for asset, accounts_data in self.asset_account_info.items():
|
||||
|
@ -300,21 +291,20 @@ class GatherAccountsManager(AccountBasePlaybookManager):
|
|||
ori_account = self.ori_gathered_accounts_mapper.get('{}_{}'.format(asset.id, username))
|
||||
|
||||
if not ori_account:
|
||||
self.batch_create_gathered_account(d)
|
||||
self.create_gathered_account(d)
|
||||
else:
|
||||
self.batch_update_gathered_account(ori_account, d)
|
||||
risk_analyser.batch_analyse_risk(asset, ori_account, d)
|
||||
self.update_gathered_account(ori_account, d)
|
||||
risk_analyser.analyse_risk(asset, ori_account, d)
|
||||
|
||||
self.update_gather_accounts_status(asset)
|
||||
GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account)
|
||||
|
||||
self.batch_create_gathered_account(None)
|
||||
self.batch_update_gathered_account(None, None)
|
||||
risk_analyser.batch_analyse_risk(None, None, {})
|
||||
self.create_gathered_account.finish()
|
||||
self.update_gathered_account.finish()
|
||||
risk_analyser.finish()
|
||||
|
||||
def before_run(self):
|
||||
super().before_run()
|
||||
self.prefetch_origin_account_usernames()
|
||||
def send_report_if_need(self):
|
||||
pass
|
||||
|
||||
def generate_send_users_and_change_info(self):
|
||||
recipients = self.execution.recipients
|
||||
|
|
|
@ -98,13 +98,37 @@ class BaseManager:
|
|||
def get_assets_group_by_platform(self):
|
||||
return self.execution.all_assets_group_by_platform()
|
||||
|
||||
def before_run(self):
|
||||
def pre_run(self):
|
||||
self.execution.date_start = timezone.now()
|
||||
self.execution.save(update_fields=['date_start'])
|
||||
|
||||
def update_execution(self):
|
||||
self.duration = int(time.time() - self.time_start)
|
||||
self.execution.date_finished = timezone.now()
|
||||
self.execution.duration = self.duration
|
||||
self.execution.summary = self.summary
|
||||
self.execution.result = self.result
|
||||
self.execution.status = 'success'
|
||||
|
||||
with safe_db_connection():
|
||||
self.execution.save()
|
||||
|
||||
def print_summary(self):
|
||||
pass
|
||||
|
||||
def get_report_template(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_report_subject(self):
|
||||
return f'Automation {self.execution.id} finished'
|
||||
|
||||
def get_report_context(self):
|
||||
return {
|
||||
'execution': self.execution,
|
||||
'summary': self.execution.summary,
|
||||
'result': self.execution.result
|
||||
}
|
||||
|
||||
def send_report_if_need(self):
|
||||
recipients = self.execution.recipients
|
||||
if not recipients:
|
||||
|
@ -113,45 +137,24 @@ class BaseManager:
|
|||
report = self.gen_report()
|
||||
report = transform(report)
|
||||
subject = self.get_report_subject()
|
||||
print("Send resport to: {}".format([str(r) for r in recipients]))
|
||||
emails = [r.email for r in recipients if r.email]
|
||||
send_mail_async(subject, report, emails, html_message=report)
|
||||
|
||||
def update_execution(self):
|
||||
self.duration = int(time.time() - self.time_start)
|
||||
self.execution.date_finished = timezone.now()
|
||||
self.execution.duration = self.duration
|
||||
self.execution.summary = self.summary
|
||||
self.execution.result = self.result
|
||||
|
||||
with safe_db_connection():
|
||||
self.execution.save(update_fields=['date_finished', 'duration', 'summary', 'result'])
|
||||
|
||||
def print_summary(self):
|
||||
pass
|
||||
|
||||
def get_template_path(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def gen_report(self):
|
||||
template_path = self.get_template_path()
|
||||
context = {
|
||||
'execution': self.execution,
|
||||
'summary': self.execution.summary,
|
||||
'result': self.execution.result
|
||||
}
|
||||
template_path = self.get_report_template()
|
||||
context = self.get_report_context()
|
||||
data = render_to_string(template_path, context)
|
||||
return data
|
||||
|
||||
def after_run(self):
|
||||
def post_run(self):
|
||||
self.update_execution()
|
||||
self.print_summary()
|
||||
self.send_report_if_need()
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
self.before_run()
|
||||
self.pre_run()
|
||||
self.do_run(*args, **kwargs)
|
||||
self.after_run()
|
||||
self.post_run()
|
||||
|
||||
def do_run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
@ -374,15 +377,19 @@ class BasePlaybookManager(BaseManager):
|
|||
if settings.DEBUG_DEV:
|
||||
print('host error: {} -> {}'.format(host, error))
|
||||
|
||||
def _on_host_success(self, host, result, error, detail):
|
||||
self.on_host_success(host, result)
|
||||
def _on_host_success(self, host, result, hosts):
|
||||
self.on_host_success(host, result.get("ok", ''))
|
||||
|
||||
def _on_host_error(self, host, result, error, detail):
|
||||
def _on_host_error(self, host, result, hosts):
|
||||
error = hosts.get(host, '')
|
||||
detail = result.get('failures', '') or result.get('dark', '')
|
||||
self.on_host_error(host, error, detail)
|
||||
|
||||
def on_runner_success(self, runner, cb):
|
||||
summary = cb.summary
|
||||
for state, hosts in summary.items():
|
||||
# 错误行为为,host 是 dict, ok 时是 list
|
||||
|
||||
if state == 'ok':
|
||||
handler = self._on_host_success
|
||||
elif state == 'skipped':
|
||||
|
@ -392,9 +399,7 @@ class BasePlaybookManager(BaseManager):
|
|||
|
||||
for host in hosts:
|
||||
result = cb.host_results.get(host)
|
||||
error = hosts.get(host, '')
|
||||
detail = result.get('failures', '') or result.get('dark', '')
|
||||
handler(host, result, error, detail)
|
||||
handler(host, result, hosts)
|
||||
|
||||
def on_runner_failed(self, runner, e):
|
||||
print("Runner failed: {} {}".format(e, self))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from contextlib import contextmanager
|
||||
|
||||
from django.db import connections, transaction
|
||||
from django.db import connections, transaction, connection
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from common.utils import get_logger, signer, crypto
|
||||
|
@ -13,7 +13,7 @@ def get_object_if_need(model, pk):
|
|||
try:
|
||||
return model.objects.get(id=pk)
|
||||
except model.DoesNotExist as e:
|
||||
logger.error(f'DoesNotExist: <{model.__name__}:{pk}> not exist')
|
||||
logger.error(f"DoesNotExist: <{model.__name__}:{pk}> not exist")
|
||||
raise e
|
||||
return pk
|
||||
|
||||
|
@ -26,8 +26,8 @@ def get_objects_if_need(model, pks):
|
|||
if len(objs) != len(pks):
|
||||
pks = set(pks)
|
||||
exists_pks = {o.id for o in objs}
|
||||
not_found_pks = ','.join(pks - exists_pks)
|
||||
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>')
|
||||
not_found_pks = ",".join(pks - exists_pks)
|
||||
logger.error(f"DoesNotExist: <{model.__name__}: {not_found_pks}>")
|
||||
return objs
|
||||
return pks
|
||||
|
||||
|
@ -41,7 +41,7 @@ def get_objects(model, pks):
|
|||
pks = set(pks)
|
||||
exists_pks = {o.id for o in objs}
|
||||
not_found_pks = pks - exists_pks
|
||||
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>')
|
||||
logger.error(f"DoesNotExist: <{model.__name__}: {not_found_pks}>")
|
||||
return objs
|
||||
|
||||
|
||||
|
@ -53,13 +53,17 @@ def close_old_connections():
|
|||
|
||||
@contextmanager
|
||||
def safe_db_connection():
|
||||
close_old_connections()
|
||||
yield
|
||||
close_old_connections()
|
||||
try:
|
||||
close_old_connections() # 确保旧连接关闭
|
||||
if connection.connection: # 如果连接已关闭,重新连接
|
||||
connection.connect()
|
||||
yield
|
||||
finally:
|
||||
close_old_connections() # 确保最终关闭连接
|
||||
|
||||
|
||||
@contextmanager
|
||||
def open_db_connection(alias='default'):
|
||||
def open_db_connection(alias="default"):
|
||||
connection = transaction.get_connection(alias)
|
||||
try:
|
||||
connection.connect()
|
||||
|
|
|
@ -11,8 +11,8 @@ from functools import wraps
|
|||
|
||||
from django.db import transaction
|
||||
|
||||
from .utils import logger
|
||||
from .db.utils import open_db_connection
|
||||
from .utils import logger
|
||||
|
||||
|
||||
def on_transaction_commit(func):
|
||||
|
@ -294,3 +294,104 @@ def cached_method(ttl=20):
|
|||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def bulk_create_decorator(instance_model, batch_size=50):
|
||||
"""
|
||||
装饰器,用于将实例批量保存,并提供 `commit` 方法提交剩余的实例。
|
||||
|
||||
:param instance_model: Django模型类,用于调用 bulk_create 方法。
|
||||
:param batch_size: 批量保存的阈值,默认50。
|
||||
"""
|
||||
def decorator(func):
|
||||
cache = [] # 缓存实例的列表
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal cache
|
||||
|
||||
# 调用被装饰的函数,生成一个实例
|
||||
instance = func(*args, **kwargs)
|
||||
if instance is None:
|
||||
return None
|
||||
|
||||
# 添加实例到缓存
|
||||
cache.append(instance)
|
||||
print(f"Instance added to cache. Cache size: {len(cache)}")
|
||||
|
||||
# 如果缓存大小达到批量保存阈值,执行保存
|
||||
if len(cache) >= batch_size:
|
||||
print(f"Batch size reached. Saving {len(cache)} instances...")
|
||||
instance_model.objects.bulk_create(cache)
|
||||
cache.clear()
|
||||
|
||||
return instance
|
||||
|
||||
# 提交剩余实例的方法
|
||||
def commit():
|
||||
nonlocal cache
|
||||
if cache:
|
||||
print(f"Committing remaining {len(cache)} instances...")
|
||||
instance_model.objects.bulk_create(cache)
|
||||
cache.clear()
|
||||
|
||||
wrapper.finish = commit
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def bulk_update_decorator(instance_model, batch_size=50, update_fields=None):
|
||||
"""
|
||||
装饰器,用于批量更新实例,并提供 `commit` 方法提交剩余的更新。
|
||||
|
||||
:param instance_model: Django模型类,用于调用 bulk_update 方法。
|
||||
:param batch_size: 批量更新的阈值,默认50。
|
||||
:param update_fields: 指定要更新的字段列表,默认为 None,表示更新所有字段。
|
||||
"""
|
||||
def decorator(func):
|
||||
cache = [] # 缓存待更新实例的列表
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal cache
|
||||
|
||||
# 调用被装饰的函数,获取一个需要更新的实例
|
||||
instance = func(*args, **kwargs)
|
||||
if instance is None:
|
||||
return None
|
||||
|
||||
# 添加实例到缓存
|
||||
cache.append(instance)
|
||||
print(f"Instance added to update cache. Cache size: {len(cache)}")
|
||||
|
||||
# 如果缓存大小达到批量更新阈值,执行更新
|
||||
if len(cache) >= batch_size:
|
||||
print(f"Batch size reached. Updating {len(cache)} instances...")
|
||||
instance_model.objects.bulk_update(cache, update_fields)
|
||||
cache.clear()
|
||||
|
||||
return instance
|
||||
|
||||
# 提交剩余更新的方法
|
||||
def commit():
|
||||
nonlocal cache
|
||||
if cache:
|
||||
print(f"Committing remaining {len(cache)} instances..., {update_fields}")
|
||||
# with transaction.atomic():
|
||||
# for c in cache:
|
||||
# o = instance_model.objects.get(id=str(c.id))
|
||||
# print("Origin: ", o.id, o.sudoers)
|
||||
# o.sudoers = c.sudoers
|
||||
# o.save()
|
||||
# print("New: ", c.id, c.sudoers)
|
||||
instance_model.objects.bulk_update(cache, update_fields)
|
||||
# print("Committing remaining instances... done, ", cache[0].sudoers, cache[0].id, instance_model)
|
||||
# print(instance_model.objects.get(id=str(cache[0].id)).sudoers)
|
||||
cache.clear()
|
||||
|
||||
# 将 commit 方法绑定到装饰后的函数
|
||||
wrapper.finish = commit
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
Loading…
Reference in New Issue