mirror of https://github.com/jumpserver/jumpserver
perf: update bulk create decorator
parent
886875d628
commit
a3b3254c35
|
@ -81,10 +81,10 @@ class GatheredAccountViewSet(OrgBulkModelViewSet):
|
|||
'nodes': [],
|
||||
'type': 'gather_accounts',
|
||||
'is_sync_account': False,
|
||||
'check_risk': True,
|
||||
'name': 'Adhoc gather accounts: {}'.format(asset_id),
|
||||
}
|
||||
with transaction.atomic():
|
||||
execution.save()
|
||||
execution.save()
|
||||
execution.start()
|
||||
accounts = self.model.objects.filter(asset=asset).prefetch_related('asset', 'asset__platform')
|
||||
return self.get_paginated_response_from_queryset(accounts)
|
||||
|
|
|
@ -163,7 +163,6 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
self.get_or_create_record(asset, account, new_secret, h['name'])
|
||||
inventory_hosts.append(h)
|
||||
|
||||
self.create_record.finish()
|
||||
return inventory_hosts
|
||||
|
||||
def on_host_success(self, host, result):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from django.utils import timezone
|
||||
|
@ -122,10 +123,6 @@ class AnalyseAccountRisk:
|
|||
def _update_risk(self, account):
|
||||
return account
|
||||
|
||||
def finish(self):
|
||||
self._create_risk.finish()
|
||||
self._update_risk.finish()
|
||||
|
||||
def analyse_risk(self, asset, ori_account, d):
|
||||
if not self.check_risk:
|
||||
return
|
||||
|
@ -134,7 +131,7 @@ class AnalyseAccountRisk:
|
|||
if ori_account:
|
||||
self._analyse_item_changed(ori_account, d)
|
||||
else:
|
||||
self._create_risk(dict(**basic, risk='new_account'))
|
||||
self._create_risk(dict(**basic, risk='ghost', details=[{'datetime': self.now.isoformat()}]))
|
||||
|
||||
self._analyse_datetime_changed(ori_account, d, asset, d['username'])
|
||||
|
||||
|
@ -298,10 +295,8 @@ class GatherAccountsManager(AccountBasePlaybookManager):
|
|||
|
||||
self.update_gather_accounts_status(asset)
|
||||
GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account)
|
||||
|
||||
self.create_gathered_account.finish()
|
||||
self.update_gathered_account.finish()
|
||||
risk_analyser.finish()
|
||||
# 因为有 bulk create, bulk update, 所以这里需要 sleep 一下,等待数据同步
|
||||
time.sleep(0.5)
|
||||
|
||||
def send_report_if_need(self):
|
||||
pass
|
||||
|
|
|
@ -31,6 +31,10 @@ class AccountRiskSerializer(serializers.ModelSerializer):
|
|||
'date_created', 'details',
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def setup_eager_loading(cls, queryset):
|
||||
return queryset.select_related('asset')
|
||||
|
||||
|
||||
class RiskSummarySerializer(serializers.Serializer):
|
||||
risk = serializers.CharField(max_length=128)
|
||||
|
|
|
@ -55,15 +55,18 @@ def close_old_connections():
|
|||
def safe_db_connection():
|
||||
in_atomic_block = connection.in_atomic_block # 当前是否处于事务中
|
||||
autocommit = transaction.get_autocommit() # 是否启用了自动提交
|
||||
created = False
|
||||
|
||||
try:
|
||||
if not connection.is_usable():
|
||||
connection.close()
|
||||
connection.connect()
|
||||
created = True
|
||||
yield
|
||||
finally:
|
||||
# 如果不是事务中(API 请求中可能需要提交事务),则关闭连接
|
||||
if not in_atomic_block and autocommit:
|
||||
if created and not in_atomic_block and autocommit:
|
||||
print("close connection in safe_db_connection")
|
||||
close_old_connections()
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from functools import wraps
|
|||
|
||||
from django.db import transaction
|
||||
|
||||
from .db.utils import open_db_connection
|
||||
from .db.utils import open_db_connection, safe_db_connection
|
||||
from .utils import logger
|
||||
|
||||
|
||||
|
@ -296,14 +296,48 @@ def cached_method(ttl=20):
|
|||
return decorator
|
||||
|
||||
|
||||
def bulk_create_decorator(instance_model, batch_size=50, ignore_conflict=True):
|
||||
def bulk_handle(handler, batch_size=50, timeout=0.5):
|
||||
def decorator(func):
|
||||
from orgs.utils import get_current_org_id
|
||||
|
||||
cache = [] # 缓存实例的列表
|
||||
lock = threading.Lock() # 用于线程安全
|
||||
timer = [None] # 定时器对象,列表存储以便重置
|
||||
org_id = None
|
||||
|
||||
def reset_timer():
|
||||
"""重置定时器"""
|
||||
if timer[0] is not None:
|
||||
timer[0].cancel()
|
||||
timer[0] = threading.Timer(timeout, handle_remaining)
|
||||
timer[0].start()
|
||||
|
||||
def handle_it():
|
||||
from orgs.utils import tmp_to_org
|
||||
with lock:
|
||||
if not cache:
|
||||
return
|
||||
with tmp_to_org(org_id):
|
||||
with safe_db_connection():
|
||||
handler(cache)
|
||||
cache.clear()
|
||||
|
||||
def handle_on_org_changed():
|
||||
nonlocal org_id
|
||||
if org_id is None:
|
||||
org_id = get_current_org_id()
|
||||
else:
|
||||
c_org_id = get_current_org_id()
|
||||
if org_id != c_org_id:
|
||||
handle_it()
|
||||
org_id = c_org_id
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal cache
|
||||
|
||||
handle_on_org_changed()
|
||||
|
||||
# 调用被装饰的函数,生成一个实例
|
||||
instance = func(*args, **kwargs)
|
||||
if instance is None:
|
||||
|
@ -315,68 +349,36 @@ def bulk_create_decorator(instance_model, batch_size=50, ignore_conflict=True):
|
|||
|
||||
# 如果缓存大小达到批量保存阈值,执行保存
|
||||
if len(cache) >= batch_size:
|
||||
print(f"Batch size reached. Saving {len(cache)} instances...")
|
||||
instance_model.objects.bulk_create(cache, ignore_conflict=ignore_conflict)
|
||||
cache.clear()
|
||||
handle_it()
|
||||
|
||||
reset_timer()
|
||||
return instance
|
||||
|
||||
# 提交剩余实例的方法
|
||||
def commit():
|
||||
nonlocal cache
|
||||
if cache:
|
||||
print(f"Committing remaining {len(cache)} instances...")
|
||||
instance_model.objects.bulk_create(cache)
|
||||
cache.clear()
|
||||
def handle_remaining():
|
||||
if not cache:
|
||||
return
|
||||
print("Timer expired. Saving remaining instances.")
|
||||
from orgs.utils import tmp_to_org
|
||||
with tmp_to_org(org_id):
|
||||
handle_it()
|
||||
|
||||
wrapper.finish = commit
|
||||
wrapper.finish = handle_remaining
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def bulk_update_decorator(instance_model, batch_size=50, update_fields=None):
|
||||
"""
|
||||
装饰器,用于批量更新实例,并提供 `commit` 方法提交剩余的更新。
|
||||
def bulk_create_decorator(instance_model, batch_size=50, ignore_conflicts=True, timeout=0.3):
|
||||
def handle(cache):
|
||||
instance_model.objects.bulk_create(cache, ignore_conflicts=ignore_conflicts)
|
||||
|
||||
:param instance_model: Django模型类,用于调用 bulk_update 方法。
|
||||
:param batch_size: 批量更新的阈值,默认50。
|
||||
:param update_fields: 指定要更新的字段列表,默认为 None,表示更新所有字段。
|
||||
"""
|
||||
def decorator(func):
|
||||
cache = [] # 缓存待更新实例的列表
|
||||
return bulk_handle(handle, batch_size, timeout)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal cache
|
||||
|
||||
# 调用被装饰的函数,获取一个需要更新的实例
|
||||
instance = func(*args, **kwargs)
|
||||
if instance is None:
|
||||
return None
|
||||
def bulk_update_decorator(instance_model, batch_size=50, update_fields=None, timeout=0.3):
|
||||
def handle(cache):
|
||||
instance_model.objects.bulk_update(cache, update_fields)
|
||||
|
||||
# 添加实例到缓存
|
||||
cache.append(instance)
|
||||
print(f"Instance added to update cache. Cache size: {len(cache)}")
|
||||
return bulk_handle(handle, batch_size, timeout)
|
||||
|
||||
# 如果缓存大小达到批量更新阈值,执行更新
|
||||
if len(cache) >= batch_size:
|
||||
print(f"Batch size reached. Updating {len(cache)} instances...")
|
||||
instance_model.objects.bulk_update(cache, update_fields)
|
||||
cache.clear()
|
||||
|
||||
return instance
|
||||
|
||||
# 提交剩余更新的方法
|
||||
def commit():
|
||||
nonlocal cache
|
||||
if cache:
|
||||
print(f"Committing remaining {len(cache)} instances..., {update_fields}")
|
||||
instance_model.objects.bulk_update(cache, update_fields)
|
||||
cache.clear()
|
||||
|
||||
# 将 commit 方法绑定到装饰后的函数
|
||||
wrapper.finish = commit
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
Loading…
Reference in New Issue