perf: update bulk create decorator

pull/14534/head
ibuler 2024-11-20 16:27:41 +08:00
parent 886875d628
commit a3b3254c35
6 changed files with 68 additions and 65 deletions

View File

@@ -81,10 +81,10 @@ class GatheredAccountViewSet(OrgBulkModelViewSet):
'nodes': [],
'type': 'gather_accounts',
'is_sync_account': False,
'check_risk': True,
'name': 'Adhoc gather accounts: {}'.format(asset_id),
}
with transaction.atomic():
execution.save()
execution.save()
execution.start()
accounts = self.model.objects.filter(asset=asset).prefetch_related('asset', 'asset__platform')
return self.get_paginated_response_from_queryset(accounts)

View File

@@ -163,7 +163,6 @@ class ChangeSecretManager(AccountBasePlaybookManager):
self.get_or_create_record(asset, account, new_secret, h['name'])
inventory_hosts.append(h)
self.create_record.finish()
return inventory_hosts
def on_host_success(self, host, result):

View File

@@ -1,3 +1,4 @@
import time
from collections import defaultdict
from django.utils import timezone
@@ -122,10 +123,6 @@ class AnalyseAccountRisk:
def _update_risk(self, account):
return account
def finish(self):
self._create_risk.finish()
self._update_risk.finish()
def analyse_risk(self, asset, ori_account, d):
if not self.check_risk:
return
@@ -134,7 +131,7 @@
if ori_account:
self._analyse_item_changed(ori_account, d)
else:
self._create_risk(dict(**basic, risk='new_account'))
self._create_risk(dict(**basic, risk='ghost', details=[{'datetime': self.now.isoformat()}]))
self._analyse_datetime_changed(ori_account, d, asset, d['username'])
@@ -298,10 +295,8 @@ class GatherAccountsManager(AccountBasePlaybookManager):
self.update_gather_accounts_status(asset)
GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account)
self.create_gathered_account.finish()
self.update_gathered_account.finish()
risk_analyser.finish()
# 因为有 bulk create, bulk update, 所以这里需要 sleep 一下,等待数据同步
time.sleep(0.5)
def send_report_if_need(self):
pass

View File

@@ -31,6 +31,10 @@ class AccountRiskSerializer(serializers.ModelSerializer):
'date_created', 'details',
]
@classmethod
def setup_eager_loading(cls, queryset):
return queryset.select_related('asset')
class RiskSummarySerializer(serializers.Serializer):
risk = serializers.CharField(max_length=128)

View File

@@ -55,15 +55,18 @@ def close_old_connections():
def safe_db_connection():
in_atomic_block = connection.in_atomic_block # 当前是否处于事务中
autocommit = transaction.get_autocommit() # 是否启用了自动提交
created = False
try:
if not connection.is_usable():
connection.close()
connection.connect()
created = True
yield
finally:
# 如果不是事务中API 请求中可能需要提交事务),则关闭连接
if not in_atomic_block and autocommit:
if created and not in_atomic_block and autocommit:
print("close connection in safe_db_connection")
close_old_connections()

View File

@@ -11,7 +11,7 @@ from functools import wraps
from django.db import transaction
from .db.utils import open_db_connection
from .db.utils import open_db_connection, safe_db_connection
from .utils import logger
@@ -296,14 +296,48 @@ def cached_method(ttl=20):
return decorator
def bulk_create_decorator(instance_model, batch_size=50, ignore_conflict=True):
def bulk_handle(handler, batch_size=50, timeout=0.5):
def decorator(func):
from orgs.utils import get_current_org_id
cache = [] # 缓存实例的列表
lock = threading.Lock() # 用于线程安全
timer = [None] # 定时器对象,列表存储以便重置
org_id = None
def reset_timer():
"""重置定时器"""
if timer[0] is not None:
timer[0].cancel()
timer[0] = threading.Timer(timeout, handle_remaining)
timer[0].start()
def handle_it():
from orgs.utils import tmp_to_org
with lock:
if not cache:
return
with tmp_to_org(org_id):
with safe_db_connection():
handler(cache)
cache.clear()
def handle_on_org_changed():
nonlocal org_id
if org_id is None:
org_id = get_current_org_id()
else:
c_org_id = get_current_org_id()
if org_id != c_org_id:
handle_it()
org_id = c_org_id
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal cache
handle_on_org_changed()
# 调用被装饰的函数,生成一个实例
instance = func(*args, **kwargs)
if instance is None:
@@ -315,68 +349,36 @@ def bulk_create_decorator(instance_model, batch_size=50, ignore_conflict=True):
# 如果缓存大小达到批量保存阈值,执行保存
if len(cache) >= batch_size:
print(f"Batch size reached. Saving {len(cache)} instances...")
instance_model.objects.bulk_create(cache, ignore_conflict=ignore_conflict)
cache.clear()
handle_it()
reset_timer()
return instance
# 提交剩余实例的方法
def commit():
nonlocal cache
if cache:
print(f"Committing remaining {len(cache)} instances...")
instance_model.objects.bulk_create(cache)
cache.clear()
def handle_remaining():
if not cache:
return
print("Timer expired. Saving remaining instances.")
from orgs.utils import tmp_to_org
with tmp_to_org(org_id):
handle_it()
wrapper.finish = commit
wrapper.finish = handle_remaining
return wrapper
return decorator
def bulk_update_decorator(instance_model, batch_size=50, update_fields=None):
"""
装饰器用于批量更新实例并提供 `commit` 方法提交剩余的更新
def bulk_create_decorator(instance_model, batch_size=50, ignore_conflicts=True, timeout=0.3):
def handle(cache):
instance_model.objects.bulk_create(cache, ignore_conflicts=ignore_conflicts)
:param instance_model: Django模型类用于调用 bulk_update 方法
:param batch_size: 批量更新的阈值默认50
:param update_fields: 指定要更新的字段列表默认为 None表示更新所有字段
"""
def decorator(func):
cache = [] # 缓存待更新实例的列表
return bulk_handle(handle, batch_size, timeout)
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal cache
# 调用被装饰的函数,获取一个需要更新的实例
instance = func(*args, **kwargs)
if instance is None:
return None
def bulk_update_decorator(instance_model, batch_size=50, update_fields=None, timeout=0.3):
def handle(cache):
instance_model.objects.bulk_update(cache, update_fields)
# 添加实例到缓存
cache.append(instance)
print(f"Instance added to update cache. Cache size: {len(cache)}")
return bulk_handle(handle, batch_size, timeout)
# 如果缓存大小达到批量更新阈值,执行更新
if len(cache) >= batch_size:
print(f"Batch size reached. Updating {len(cache)} instances...")
instance_model.objects.bulk_update(cache, update_fields)
cache.clear()
return instance
# 提交剩余更新的方法
def commit():
nonlocal cache
if cache:
print(f"Committing remaining {len(cache)} instances..., {update_fields}")
instance_model.objects.bulk_update(cache, update_fields)
cache.clear()
# 将 commit 方法绑定到装饰后的函数
wrapper.finish = commit
return wrapper
return decorator