jumpserver/apps/common/signal_handlers.py

185 lines
5.5 KiB
Python

# -*- coding: utf-8 -*-
#
import logging
import os
import re
from collections import defaultdict
from django.conf import settings
from django.core.signals import request_finished
from django.db import connection
from django.db.models.signals import pre_save
from django.dispatch import receiver
from jumpserver.utils import get_current_request
from .local import thread_local
from .signals import django_ready
# Extract the table name that follows ``FROM`` in a logged SQL statement.
# MySQL quotes identifiers with backticks while PostgreSQL and SQLite use
# double quotes; accept both so the per-table digest works on every backend.
pattern = re.compile(r'FROM [`"](\w+)[`"]')
logger = logging.getLogger("jumpserver.common")
class Counter:
    """Accumulate how many SQL queries hit a table and how long they took.

    Instances are ordered and compared by ``counter`` (the query count)
    only, so a list of counters can be sorted by how often a table was
    queried. Comparing with a non-``Counter`` returns ``NotImplemented``
    instead of raising ``AttributeError``.
    """

    # ``__eq__`` is defined and instances are mutated in place, so keep them
    # unhashable (this is what Python does implicitly; made explicit here).
    __hash__ = None

    def __init__(self):
        # Number of queries recorded.
        self.counter = 0
        # Summed query time (reported as seconds by the digest logging).
        self.time = 0

    def __repr__(self):
        return f"Counter(counter={self.counter}, time={self.time})"

    def __gt__(self, other):
        if not isinstance(other, Counter):
            return NotImplemented
        return self.counter > other.counter

    def __lt__(self, other):
        if not isinstance(other, Counter):
            return NotImplemented
        return self.counter < other.counter

    def __eq__(self, other):
        if not isinstance(other, Counter):
            return NotImplemented
        return self.counter == other.counter
def digest_sql_query():
    """Print and log a per-table summary of this connection's SELECT queries.

    Reads ``connection.queries`` and groups every non-empty ``SELECT``
    statement by the table name(s) that ``pattern`` finds after ``FROM``.
    It then does two things:

    * prints each statement of every table queried three or more times,
      skipping tables whose name starts with ``rbac_`` or ``auth_permission``;
    * logs, at DEBUG level, the query count and summed time per table,
      sorted by count and including a ``total`` row.

    Returns ``None``. When no SELECT was recorded it returns before
    printing anything.
    """
    queries = connection.queries
    counters = defaultdict(Counter)
    table_queries = defaultdict(list)
    for query in queries:
        sql = query['sql']
        if not sql or not sql.startswith('SELECT'):
            continue
        # Sub-queries can yield several FROM tables; join them with a
        # separator so the resulting key stays readable and unambiguous.
        table_name = ','.join(pattern.findall(sql))
        elapsed = float(query['time'])
        for name in (table_name, 'total'):
            counters[name].counter += 1
            counters[name].time += elapsed
        table_queries[table_name].append(query)

    if not counters:
        return
    # ``Counter`` orders by query count, so this sorts ascending by count.
    sorted_counters = sorted(counters.items(), key=lambda item: item[1])

    method = 'GET'
    path = '/Unknown'
    current_request = get_current_request()
    if current_request:
        method = current_request.method
        path = current_request.get_full_path()

    print(">>> [{}] {}".format(method, path))
    for table_name, select_queries in table_queries.items():
        # rbac_* and auth_permission* tables are deliberately left out.
        if table_name.startswith(('rbac_', 'auth_permission')):
            continue
        # Only list tables hit repeatedly (presumably to spot N+1 patterns).
        if len(select_queries) < 3:
            continue
        print("- Table: {}".format(table_name))
        # Every entry is already a non-empty SELECT (filtered above).
        for i, query in enumerate(select_queries, 1):
            print('\t{}. {}'.format(i, query['sql']))

    logger.debug(">>> [%s] %s", method, path)
    for name, counter in sorted_counters:
        logger.debug(
            "Query %3d times using %.2fs %s",
            counter.counter, counter.time, name,
        )
def on_request_finished_logging_db_query(sender, **kwargs):
    """``request_finished`` handler used when ``settings.DEBUG_DEV`` is on.

    Emits the SQL digest of the finished request, then releases the
    thread-local state. The release is in ``finally`` so that an error
    while building the digest cannot leave stale thread-local values
    behind for later work on the same thread.
    """
    try:
        digest_sql_query()
    finally:
        on_request_finished_release_local(sender, **kwargs)
def on_request_finished_release_local(sender, **kwargs):
    """``request_finished`` handler that releases ``thread_local``.

    Calls ``thread_local.__release_local__()``, which presumably clears the
    values stored for the current thread (see ``.local``; confirm). The
    ``sender`` and ``kwargs`` arguments come from the signal signature and
    are not used.
    """
    release = thread_local.__release_local__
    release()
def _get_request_user_name():
    """Return the name to record as the author of a model change.

    Returns the ``name`` of the current request's authenticated user, cut
    to 30 characters when it is a string. Falls back to ``'System'`` in
    three cases: there is no current request, the request has no ``user``
    attribute, or the user is not authenticated.
    """
    current_request = get_current_request()
    # ``request.user`` is only attached by Django's AuthenticationMiddleware;
    # a save that happens before it runs must not crash the pre_save handlers.
    user = getattr(current_request, 'user', None)
    if user is None or not user.is_authenticated:
        return 'System'
    user_name = user.name
    if isinstance(user_name, str):
        user_name = user_name[:30]
    return user_name
@receiver(pre_save)
def on_create_set_created_by(sender, instance=None, **kwargs):
    """Before any model instance is saved, fill an empty ``created_by``.

    The value comes from ``_get_request_user_name()``. Nothing happens when
    the instance has no ``created_by`` attribute, when it is already set,
    or when the instance sets ``_ignore_auto_created_by`` to a truthy value.
    """
    skip = (
        getattr(instance, '_ignore_auto_created_by', False)
        or not hasattr(instance, 'created_by')
        or instance.created_by
    )
    if skip:
        return
    instance.created_by = _get_request_user_name()
@receiver(pre_save)
def on_update_set_updated_by(sender, instance=None, created=False, **kwargs):
    """Before any model instance is saved, overwrite ``updated_by``.

    The value comes from ``_get_request_user_name()``. Nothing happens when
    the instance has no ``updated_by`` attribute or sets
    ``_ignore_auto_updated_by`` to a truthy value. ``created`` is kept for
    signature compatibility and is not used.
    """
    opted_out = getattr(instance, '_ignore_auto_updated_by', False)
    if opted_out or not hasattr(instance, 'updated_by'):
        return
    instance.updated_by = _get_request_user_name()
# In DEBUG_DEV mode, also print/log the SQL digest when a request finishes.
# Either way, the thread-local state is released at the end of each request.
request_finished.connect(
    on_request_finished_logging_db_query
    if settings.DEBUG_DEV
    else on_request_finished_release_local
)
@receiver(django_ready)
def check_migrations_file_prefix_conflict(*args, **kwargs):
    """On ``django_ready``, report migration files that share a name prefix.

    Only active when ``settings.DEBUG_DEV`` is true. Scans each direct
    sub-directory of ``BASE_DIR`` for a ``migrations`` folder and prints
    the number of conflicts. If there are any, it also prints each
    clashing pair of file names. A prefix is the part of the file name
    before the first ``_``.
    """
    if not settings.DEBUG_DEV:
        return
    from jumpserver.const import BASE_DIR

    print('>>> Check migrations file prefix conflict.', end=' ')
    conflict_files = _find_migration_prefix_conflicts(BASE_DIR)
    conflict_count = len(conflict_files)
    print(f'Conflict count:({conflict_count})')
    if not conflict_count:
        return
    _print_migration_conflicts(conflict_files)


def _find_migration_prefix_conflicts(base_dir):
    """Return ``(app_dir, file, earlier_file)`` tuples for clashing migrations.

    Looks at ``.py`` files (excluding names containing ``squashed``) in
    ``<base_dir>/<sub_dir>/migrations`` and reports every file whose prefix
    (text before the first ``_``) was already seen in the same directory.
    Directories and files are visited in sorted order, so the result is
    deterministic.
    """
    conflict_files = []
    # Immediate sub-directories of base_dir (the app directories).
    sub_dirs = next(os.walk(base_dir))[1]
    for subdir in sorted(sub_dirs):
        migrations_dir = os.path.join(base_dir, subdir, 'migrations')
        if not os.path.exists(migrations_dir):
            continue
        prefix_file_map = {}
        for file in sorted(os.listdir(migrations_dir)):
            if not file.endswith('.py') or 'squashed' in file:
                continue
            file_prefix = file.split('_')[0]
            if file_prefix in prefix_file_map:
                conflict_files.append(
                    (subdir, file, prefix_file_map[file_prefix])
                )
            else:
                prefix_file_map[file_prefix] = file
    return conflict_files


def _print_migration_conflicts(conflict_files):
    """Print each conflict as two aligned lines between ``=`` rulers."""
    print('=' * 80)
    for subdir, file, other_file in conflict_files:
        msg_left = '{:<15}'.format(subdir)
        msg_split = '=> '
        msg_right1 = msg_split + '{:<80}'.format(file)
        msg_right2 = ' ' * len(msg_left) + msg_split + other_file
        print(f'{msg_left}{msg_right1}\n{msg_right2}\n')
    print('=' * 80)