# -*- coding: utf-8 -*-
#
"""Common signal handlers: audit-field auto-fill and dev-only diagnostics."""
import logging
import os
import re
from collections import defaultdict

from django.conf import settings
from django.core.signals import request_finished
from django.db import connection
from django.db.models.signals import pre_save
from django.dispatch import receiver

from jumpserver.utils import get_current_request
from .local import thread_local
from .signals import django_ready

# Captures the table name that follows FROM. MySQL quotes identifiers with
# backticks while PostgreSQL/SQLite use double quotes; accept either so the
# per-table digest is not collapsed into a single '' bucket on those backends.
pattern = re.compile(r'FROM [`"](\w+)[`"]')
logger = logging.getLogger("jumpserver.common")


class Counter:
    """Accumulate a query count and total elapsed time for one table.

    Instances compare by ``counter`` only, so a collection of them can be
    sorted directly by how many queries were issued.
    """

    def __init__(self):
        self.counter = 0
        self.time = 0

    def __gt__(self, other):
        return self.counter > other.counter

    def __lt__(self, other):
        return self.counter < other.counter

    def __eq__(self, other):
        return self.counter == other.counter


def digest_sql_query():
    """Print and log a summary of the SELECT queries run on this connection.

    Reads ``django.db.connection.queries`` (Django records these only when
    ``settings.DEBUG`` is on) and groups SELECT statements by the table(s)
    named after ``FROM``. For every table hit at least three times, prints
    each of its queries, skipping ``rbac_*`` and ``auth_permission*`` tables.
    Then logs per-table query counts and cumulative time at DEBUG level, in
    ascending order of count, with an extra ``'total'`` row.

    Returns ``None``. Does nothing when no SELECT query was recorded.
    """
    counters = defaultdict(Counter)
    table_queries = defaultdict(list)

    for query in connection.queries:
        sql = query['sql']
        if not sql or not sql.startswith('SELECT'):
            continue
        # All FROM targets (e.g. from subqueries) are concatenated into one key.
        table_name = ''.join(pattern.findall(sql))
        elapsed = float(query['time'])
        for key in (table_name, 'total'):
            counters[key].counter += 1
            counters[key].time += elapsed
        table_queries[table_name].append(query)

    if not counters:
        return
    # Counter defines __lt__, so entries sort by query count (stable on ties).
    ranked = sorted(counters.items(), key=lambda item: item[1])

    method, path = 'GET', '/Unknown'
    current_request = get_current_request()
    if current_request:
        method = current_request.method
        path = current_request.get_full_path()

    print(">>> [{}] {}".format(method, path))
    for table_name, table_sqls in table_queries.items():
        # Permission tables are skipped, presumably because they are expected
        # to be queried often and would drown out the rest of the report.
        if table_name.startswith(('rbac_', 'auth_permission')):
            continue
        # Only tables queried repeatedly are listed (likely N+1 suspects).
        if len(table_sqls) < 3:
            continue
        print("- Table: {}".format(table_name))
        # table_queries only ever holds SELECT statements (filtered above).
        for i, query in enumerate(table_sqls, 1):
            print('\t{}. {}'.format(i, query['sql']))

    logger.debug(">>> [%s] %s", method, path)
    for name, counter in ranked:
        logger.debug(
            "Query %3d times using %.2fs %s", counter.counter, counter.time, name
        )


def on_request_finished_logging_db_query(sender, **kwargs):
    """``request_finished`` handler (dev): digest queries, then clear thread-local state."""
    digest_sql_query()
    on_request_finished_release_local(sender, **kwargs)


def on_request_finished_release_local(sender, **kwargs):
    """``request_finished`` handler: release the per-thread ``thread_local`` storage."""
    thread_local.__release_local__()


def _get_request_user_name():
    """Return the name of the current request's authenticated user, else ``'System'``.

    Falls back to ``'System'`` in three cases: there is no current request,
    the request has no ``user`` attribute (presumably because authentication
    middleware has not run), or the user is not authenticated. String results
    are truncated to 30 characters.
    """
    user_name = 'System'
    current_request = get_current_request()
    # getattr guards against requests that were never given a ``user``.
    user = getattr(current_request, 'user', None) if current_request else None
    if user is not None and user.is_authenticated:
        user_name = user.name
    if isinstance(user_name, str):
        user_name = user_name[:30]
    return user_name


@receiver(pre_save)
def on_create_set_created_by(sender, instance=None, **kwargs):
    """Fill ``created_by`` with the current user's name if the model has it unset.

    Skipped when the instance sets ``_ignore_auto_created_by`` to a truthy value.
    """
    if getattr(instance, '_ignore_auto_created_by', False):
        return
    if not hasattr(instance, 'created_by') or instance.created_by:
        return
    user_name = _get_request_user_name()
    instance.created_by = user_name


@receiver(pre_save)
def on_update_set_updated_by(sender, instance=None, created=False, **kwargs):
    """Set ``updated_by`` to the current user's name on every save.

    Skipped when the instance sets ``_ignore_auto_updated_by`` to a truthy
    value, or when the model has no ``updated_by`` attribute.
    """
    if getattr(instance, '_ignore_auto_updated_by', False):
        return
    if not hasattr(instance, 'updated_by'):
        return
    user_name = _get_request_user_name()
    instance.updated_by = user_name


# In dev mode also print the per-request SQL digest; both paths release
# thread-local storage when a request finishes.
if settings.DEBUG_DEV:
    request_finished.connect(on_request_finished_logging_db_query)
else:
    request_finished.connect(on_request_finished_release_local)


@receiver(django_ready)
def check_migrations_file_prefix_conflict(*args, **kwargs):
    """Report (dev only) migration files in one app that share a numeric prefix.

    Scans each direct subdirectory of ``BASE_DIR`` that has a ``migrations``
    directory. Non-squashed ``.py`` files are grouped by the text before their
    first underscore (e.g. ``0042``). Each file whose prefix was already seen
    in the same app is printed next to the first file with that prefix.
    Does nothing unless ``settings.DEBUG_DEV`` is enabled.
    """
    if not settings.DEBUG_DEV:
        return
    from jumpserver.const import BASE_DIR

    print('>>> Check migrations file prefix conflict.', end=' ')
    # App directories directly under BASE_DIR. The default avoids StopIteration
    # if BASE_DIR cannot be walked. Sorting makes the report deterministic.
    _, sub_dirs, _ = next(os.walk(BASE_DIR), (BASE_DIR, [], []))
    # Each entry is (app_dir, conflicting_file, first_file_with_same_prefix).
    conflict_files = []
    for subdir in sorted(sub_dirs):
        migrations_dir = os.path.join(BASE_DIR, subdir, 'migrations')
        if not os.path.isdir(migrations_dir):
            continue
        prefix_file_map = {}
        # Sorted so the lexicographically first file is treated as the original.
        for file in sorted(os.listdir(migrations_dir)):
            if not file.endswith('.py') or 'squashed' in file:
                continue
            file_prefix = file.split('_')[0]
            first = prefix_file_map.setdefault(file_prefix, file)
            if first != file:
                conflict_files.append((subdir, file, first))

    conflict_count = len(conflict_files)
    print(f'Conflict count:({conflict_count})')
    if not conflict_count:
        return

    print('=' * 80)
    for subdir, file, first in conflict_files:
        msg_left = '{:<15}'.format(subdir)
        msg_split = '=> '
        msg_right1 = msg_split + '{:<80}'.format(file)
        msg_right2 = ' ' * len(msg_left) + msg_split + first
        print(f'{msg_left}{msg_right1}\n{msg_right2}\n')
    print('=' * 80)