perf: 优化 api sql 查询

pull/9454/head
ibuler 2023-02-07 16:21:26 +08:00
parent 47c5f18c6e
commit e7202ac984
13 changed files with 150 additions and 104 deletions

View File

@ -18,7 +18,5 @@
arch: "{{ ansible_architecture }}" arch: "{{ ansible_architecture }}"
kernel: "{{ ansible_kernel }}" kernel: "{{ ansible_kernel }}"
- debug: - debug:
var: info var: info

View File

@ -87,7 +87,11 @@ class QuerySetMixin:
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()
if hasattr(self, 'action') and (self.action == 'list' or self.action == 'metadata'): if not hasattr(self, 'action'):
return queryset
if self.action == 'metadata':
queryset = queryset.none()
if self.action in ['list', 'metadata']:
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
if serializer_class and hasattr(serializer_class, 'setup_eager_loading'): if serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
queryset = serializer_class.setup_eager_loading(queryset) queryset = serializer_class.setup_eager_loading(queryset)

View File

@ -1,4 +1,4 @@
from collections import Iterable, defaultdict from collections import Iterable, defaultdict, OrderedDict
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db.models import NOT_PROVIDED from django.db.models import NOT_PROVIDED
@ -8,8 +8,8 @@ from rest_framework.fields import SkipField, empty
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html from rest_framework.utils import html
from common.serializers.fields import EncryptedField from common.db.fields import EncryptMixin
from common.serializers.fields import LabeledChoiceField, ObjectRelatedField from common.serializers.fields import EncryptedField, LabeledChoiceField, ObjectRelatedField
__all__ = [ __all__ = [
'BulkSerializerMixin', 'BulkListSerializerMixin', 'BulkSerializerMixin', 'BulkListSerializerMixin',
@ -268,6 +268,7 @@ class DefaultValueFieldsMixin:
if not hasattr(self.Meta, 'model'): if not hasattr(self.Meta, 'model'):
return return
model = self.Meta.model model = self.Meta.model
for name, serializer_field in self.fields.items(): for name, serializer_field in self.fields.items():
if serializer_field.default != empty or serializer_field.required: if serializer_field.default != empty or serializer_field.required:
continue continue
@ -335,22 +336,38 @@ class SomeFieldsMixin:
return value return value
return default return default
@staticmethod
def order_fields(fields):
bool_fields = []
datetime_fields = []
other_fields = []
for name, field in fields.items():
to_add = (name, field)
if isinstance(field, serializers.BooleanField):
bool_fields.append(to_add)
elif isinstance(field, serializers.DateTimeField):
datetime_fields.append(to_add)
else:
other_fields.append(to_add)
_fields = [*other_fields, *bool_fields, *datetime_fields]
fields = OrderedDict()
for name, field in _fields:
fields[name] = field
return fields
def get_fields(self): def get_fields(self):
fields = super().get_fields() fields = super().get_fields()
fields = self.order_fields(fields)
secret_readable = isinstance(self, SecretReadableMixin)
for name, field in fields.items(): for name, field in fields.items():
if name == 'id': if name == 'id':
field.label = 'ID' field.label = 'ID'
elif name in self.secret_fields and \ elif isinstance(field, EncryptMixin) and not secret_readable:
not isinstance(self, SecretReadableMixin):
field.write_only = True field.write_only = True
return fields return fields
def get_field_names(self, declared_fields, info):
names = super().get_field_names(declared_fields, info)
common_names = [i for i in self.common_fields if i in names]
primary_names = [i for i in names if i not in self.common_fields]
return primary_names + common_names
class CommonSerializerMixin(DynamicFieldsMixin, RelatedModelSerializerMixin, class CommonSerializerMixin(DynamicFieldsMixin, RelatedModelSerializerMixin,
SomeFieldsMixin, DefaultValueFieldsMixin): SomeFieldsMixin, DefaultValueFieldsMixin):

View File

@ -60,12 +60,6 @@ def on_request_finished_logging_db_query(sender, **kwargs):
method = current_request.method method = current_request.method
path = current_request.get_full_path() path = current_request.get_full_path()
logger.debug(">>> [{}] {}".format(method, path))
for name, counter in counters:
logger.debug("Query {:3} times using {:.2f}s {}".format(
counter.counter, counter.time, name)
)
# print(">>> [{}] {}".format(method, path)) # print(">>> [{}] {}".format(method, path))
# for table_name, queries in table_queries.items(): # for table_name, queries in table_queries.items():
# if table_name.startswith('rbac_') or table_name.startswith('auth_permission'): # if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
@ -77,6 +71,12 @@ def on_request_finished_logging_db_query(sender, **kwargs):
# continue # continue
# print('\t{}. {}'.format(i, sql)) # print('\t{}. {}'.format(i, sql))
logger.debug(">>> [{}] {}".format(method, path))
for name, counter in counters:
logger.debug("Query {:3} times using {:.2f}s {}".format(
counter.counter, counter.time, name)
)
on_request_finished_release_local(sender, **kwargs) on_request_finished_release_local(sender, **kwargs)

View File

@ -34,9 +34,6 @@ class OrgQuerySetMixin:
% self.__class__.__name__ % self.__class__.__name__
) )
queryset = super().get_queryset() queryset = super().get_queryset()
if hasattr(self, 'swagger_fake_view'):
return queryset.none()
return queryset return queryset

View File

@ -133,23 +133,27 @@ class Organization(OrgRoleMixin, JMSBaseModel):
def org_id(self): def org_id(self):
return self.id return self.id
@classmethod
def get_or_create_builtin(cls, name, **kwargs):
_id = kwargs.get('id')
org = cls.get_instance(cls.DEFAULT_ID)
if org:
return org
org, created = cls.objects.get_or_create(name=name, defaults=kwargs)
if created:
org.builtin = True
org.save()
return org
@classmethod @classmethod
def default(cls): def default(cls):
defaults = dict(id=cls.DEFAULT_ID, name=cls.DEFAULT_NAME) kwargs = {'id': cls.DEFAULT_ID, 'name': cls.DEFAULT_NAME}
obj, created = cls.objects.get_or_create(defaults=defaults, id=cls.DEFAULT_ID) return cls.get_or_create_builtin(**kwargs)
if not obj.builtin:
obj.builtin = True
obj.save()
return obj
@classmethod @classmethod
def system(cls): def system(cls):
defaults = dict(id=cls.SYSTEM_ID, name=cls.SYSTEM_NAME) kwargs = {'id': cls.SYSTEM_ID, 'name': cls.SYSTEM_NAME}
obj, created = cls.objects.get_or_create(defaults=defaults, id=cls.SYSTEM_ID) return cls.get_or_create_builtin(**kwargs)
if not obj.builtin:
obj.builtin = True
obj.save()
return obj
@classmethod @classmethod
def root(cls): def root(cls):

View File

@ -1,12 +1,13 @@
from django.db.models import Count from django.db.models import Q, Count
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import PermissionDenied from rest_framework.exceptions import PermissionDenied
from common.api import JMSModelViewSet from common.api import JMSModelViewSet
from orgs.utils import current_org
from .permission import PermissionViewSet from .permission import PermissionViewSet
from ..filters import RoleFilter from ..filters import RoleFilter
from ..models import Role, SystemRole, OrgRole from ..models import Role, SystemRole, OrgRole, RoleBinding
from ..serializers import RoleSerializer, RoleUserSerializer from ..serializers import RoleSerializer, RoleUserSerializer
__all__ = [ __all__ = [
@ -33,6 +34,7 @@ class RoleViewSet(JMSModelViewSet):
if instance.builtin: if instance.builtin:
error = _("Internal role, can't be destroy") error = _("Internal role, can't be destroy")
raise PermissionDenied(error) raise PermissionDenied(error)
with tmp_to_root_org(): with tmp_to_root_org():
if instance.users.count() >= 1: if instance.users.count() >= 1:
error = _("The role has been bound to users, can't be destroy") error = _("The role has been bound to users, can't be destroy")
@ -54,6 +56,24 @@ class RoleViewSet(JMSModelViewSet):
return return
instance.permissions.set(clone.get_permissions()) instance.permissions.set(clone.get_permissions())
@staticmethod
def set_users_amount(queryset):
"""设置角色的用户绑定数量,以减少查询"""
org_id = current_org.id
q = Q(role__scope=Role.Scope.system) | Q(role__scope=Role.Scope.org, org_id=org_id)
role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(user_count=Count('user_id'))
role_user_amount_mapper = {role_id: user_count for role_id, user_count in role_bindings}
queryset = queryset.annotate(permissions_amount=Count('permissions'))
queryset = list(queryset)
for role in queryset:
role.users_amount = role_user_amount_mapper.get(role.id, 0)
return queryset
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.set_users_amount(queryset)
return queryset
def perform_update(self, serializer): def perform_update(self, serializer):
instance = serializer.instance instance = serializer.instance
if instance.builtin: if instance.builtin:
@ -61,10 +81,6 @@ class RoleViewSet(JMSModelViewSet):
raise PermissionDenied(error) raise PermissionDenied(error)
return super().perform_update(serializer) return super().perform_update(serializer)
def get_queryset(self):
queryset = super().get_queryset().annotate(permissions_amount=Count('permissions'))
return queryset
@action(methods=['GET'], detail=True) @action(methods=['GET'], detail=True)
def users(self, *args, **kwargs): def users(self, *args, **kwargs):
role = self.get_object() role = self.get_object()
@ -73,11 +89,13 @@ class RoleViewSet(JMSModelViewSet):
class SystemRoleViewSet(RoleViewSet): class SystemRoleViewSet(RoleViewSet):
queryset = SystemRole.objects.all() def get_queryset(self):
return super().get_queryset().filter(scope='system')
class OrgRoleViewSet(RoleViewSet): class OrgRoleViewSet(RoleViewSet):
queryset = OrgRole.objects.all() def get_queryset(self):
return super().get_queryset().filter(scope='org')
class BaseRolePermissionsViewSet(PermissionViewSet): class BaseRolePermissionsViewSet(PermissionViewSet):

View File

@ -102,7 +102,7 @@ class Role(JMSBaseModel):
@lazyproperty @lazyproperty
def users_amount(self): def users_amount(self):
return self.users.count() return 0
@lazyproperty @lazyproperty
def permissions_amount(self): def permissions_amount(self):

View File

@ -92,12 +92,17 @@ class RBACPermission(permissions.DjangoModelPermissions):
try: try:
queryset = self._queryset(view) queryset = self._queryset(view)
model_cls = queryset.model if isinstance(queryset, list) and queryset:
model_cls = queryset[0].__class__
else:
model_cls = queryset.model
except AssertionError: except AssertionError:
model_cls = None model_cls = None
except AttributeError:
model_cls = None
except Exception as e: except Exception as e:
logger.error('Error get model class: {} of {}'.format(e, view)) logger.error('Error get model class: {} of {}'.format(e, view))
model_cls = None raise e
return model_cls return model_cls
def get_require_perms(self, request, view): def get_require_perms(self, request, view):

View File

@ -330,7 +330,7 @@ class RoleMixin:
id: str id: str
_org_roles = None _org_roles = None
_system_roles = None _system_roles = None
PERM_CACHE_KEY = 'USER_PERMS_{}_{}' PERM_CACHE_KEY = 'USER_PERMS_ROLES_{}_{}'
_is_superuser = None _is_superuser = None
_update_superuser = False _update_superuser = False
@ -347,13 +347,36 @@ class RoleMixin:
return SystemRoleManager(self) return SystemRoleManager(self)
@lazyproperty @lazyproperty
def perms(self): def console_orgs(self):
return self.cached_role_and_perms['console_orgs']
@lazyproperty
def audit_orgs(self):
return self.cached_role_and_perms['audit_orgs']
@lazyproperty
def workbench_orgs(self):
return self.cached_role_and_perms['workbench_orgs']
@lazyproperty
def cached_role_and_perms(self):
from rbac.models import RoleBinding
key = self.PERM_CACHE_KEY.format(self.id, current_org.id) key = self.PERM_CACHE_KEY.format(self.id, current_org.id)
perms = cache.get(key) data = cache.get(key)
if not perms or settings.DEBUG: if data:
perms = self.get_all_permissions() return data
cache.set(key, perms, 3600)
return perms data = {
'console_orgs': RoleBinding.get_user_has_the_perm_orgs('rbac.view_console', self),
'audit_orgs': RoleBinding.get_user_has_the_perm_orgs('rbac.view_audit', self),
'workbench_orgs': RoleBinding.get_user_has_the_perm_orgs('rbac.view_workbench', self),
'org_roles': self.org_roles.all(),
'system_roles': self.system_roles.all(),
'perms': self.get_all_permissions(),
}
cache.set(key, data, 60 * 60)
return data
def expire_rbac_perms_cache(self): def expire_rbac_perms_cache(self):
key = self.PERM_CACHE_KEY.format(self.id, '*') key = self.PERM_CACHE_KEY.format(self.id, '*')
@ -364,6 +387,10 @@ class RoleMixin:
key = cls.PERM_CACHE_KEY.format('*', '*') key = cls.PERM_CACHE_KEY.format('*', '*')
cache.delete_pattern(key) cache.delete_pattern(key)
@lazyproperty
def perms(self):
return self.cached_role_and_perms['perms']
@property @property
def is_superuser(self): def is_superuser(self):
""" """
@ -746,18 +773,6 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
def receive_backends(self): def receive_backends(self):
return self.user_msg_subscription.receive_backends return self.user_msg_subscription.receive_backends
@property
def is_wecom_bound(self):
return bool(self.wecom_id)
@property
def is_dingtalk_bound(self):
return bool(self.dingtalk_id)
@property
def is_feishu_bound(self):
return bool(self.feishu_id)
@property @property
def is_otp_secret_key_bound(self): def is_otp_secret_key_bound(self):
return bool(self.otp_secret_key) return bool(self.otp_secret_key)
@ -765,10 +780,6 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
def get_absolute_url(self): def get_absolute_url(self):
return reverse('users:user-detail', args=(self.id,)) return reverse('users:user-detail', args=(self.id,))
@property
def groups_display(self):
return ' '.join([group.name for group in self.groups.all()])
@property @property
def source_display(self): def source_display(self):
return self.get_source_display() return self.get_source_display()
@ -808,7 +819,7 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
oauth2 = self.Source.oauth2 oauth2 = self.Source.oauth2
return self.source not in [cas, saml2, oauth2] return self.source not in [cas, saml2, oauth2]
def set_unprovide_attr_if_need(self): def set_required_attr_if_need(self):
if not self.name: if not self.name:
self.name = self.username self.name = self.username
if not self.email or '@' not in self.email: if not self.email or '@' not in self.email:
@ -818,7 +829,7 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
self.email = email self.email = email
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
self.set_unprovide_attr_if_need() self.set_required_attr_if_need()
if self.username == 'admin': if self.username == 'admin':
self.role = 'Admin' self.role = 'Admin'
self.is_active = True self.is_active = True
@ -880,21 +891,6 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
return None return None
return self.SOURCE_BACKEND_MAPPING.get(self.source, []) return self.SOURCE_BACKEND_MAPPING.get(self.source, [])
@lazyproperty
def console_orgs(self):
from rbac.models import RoleBinding
return RoleBinding.get_user_has_the_perm_orgs('rbac.view_console', self)
@lazyproperty
def audit_orgs(self):
from rbac.models import RoleBinding
return RoleBinding.get_user_has_the_perm_orgs('rbac.view_audit', self)
@lazyproperty
def workbench_orgs(self):
from rbac.models import RoleBinding
return RoleBinding.get_user_has_the_perm_orgs('rbac.view_workbench', self)
class Meta: class Meta:
ordering = ['username'] ordering = ['username']
verbose_name = _("User") verbose_name = _("User")

View File

@ -2,11 +2,10 @@ from django.conf import settings
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from common.utils import validate_ssh_public_key
from common.serializers.fields import EncryptedField, LabeledChoiceField from common.serializers.fields import EncryptedField, LabeledChoiceField
from ..models import User from common.utils import validate_ssh_public_key
from .user import UserSerializer from .user import UserSerializer
from ..models import User
class UserOrgSerializer(serializers.Serializer): class UserOrgSerializer(serializers.Serializer):
@ -116,7 +115,6 @@ class UserProfileSerializer(UserSerializer):
(0, _('Disable')), (0, _('Disable')),
(1, _('Enable')), (1, _('Enable')),
) )
public_key_comment = serializers.CharField( public_key_comment = serializers.CharField(
source='get_public_key_comment', required=False, read_only=True, max_length=128 source='get_public_key_comment', required=False, read_only=True, max_length=128
) )

View File

@ -1,9 +1,10 @@
#!/usr/bin/env python #!/usr/bin/env python
# #
import argparse
import os import os
import sys import sys
import django import django
import argparse
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
APPS_DIR = os.path.join(BASE_DIR, 'apps') APPS_DIR = os.path.join(BASE_DIR, 'apps')
@ -17,7 +18,6 @@ from resources.users import UserGroupGenerator, UserGenerator
from resources.perms import AssetPermissionGenerator from resources.perms import AssetPermissionGenerator
from resources.terminal import CommandGenerator, SessionGenerator from resources.terminal import CommandGenerator, SessionGenerator
resource_generator_mapper = { resource_generator_mapper = {
'asset': AssetsGenerator, 'asset': AssetsGenerator,
'platform': PlatformGenerator, 'platform': PlatformGenerator,
@ -26,7 +26,8 @@ resource_generator_mapper = {
'user_group': UserGroupGenerator, 'user_group': UserGroupGenerator,
'asset_permission': AssetPermissionGenerator, 'asset_permission': AssetPermissionGenerator,
'command': CommandGenerator, 'command': CommandGenerator,
'session': SessionGenerator 'session': SessionGenerator,
'all': None
# 'stat': StatGenerator # 'stat': StatGenerator
} }
@ -36,16 +37,24 @@ def main():
parser.add_argument( parser.add_argument(
'resource', type=str, 'resource', type=str,
choices=resource_generator_mapper.keys(), choices=resource_generator_mapper.keys(),
default='all',
help="resource to generate" help="resource to generate"
) )
parser.add_argument('-c', '--count', type=int, default=100) parser.add_argument('-c', '--count', type=int, default=10000)
parser.add_argument('-b', '--batch_size', type=int, default=100) parser.add_argument('-b', '--batch_size', type=int, default=100)
parser.add_argument('-o', '--org', type=str, default='') parser.add_argument('-o', '--org', type=str, default='')
args = parser.parse_args() args = parser.parse_args()
resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org
generator_cls = resource_generator_mapper[resource]
generator = generator_cls(org_id=org_id, batch_size=batch_size) generator_cls = []
generator.generate(count) if resource == 'all':
generator_cls = resource_generator_mapper.values()
else:
generator_cls.push(resource_generator_mapper[resource])
for _cls in generator_cls:
generator = _cls(org_id=org_id, batch_size=batch_size)
generator.generate(count)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,11 +1,11 @@
from random import choice
import random import random
from random import choice
import forgery_py import forgery_py
from .base import FakeDataGenerator
from assets.models import *
from assets.const import AllTypes from assets.const import AllTypes
from assets.models import *
from .base import FakeDataGenerator
class NodesGenerator(FakeDataGenerator): class NodesGenerator(FakeDataGenerator):
@ -59,11 +59,11 @@ class AssetsGenerator(FakeDataGenerator):
assets = [] assets = []
for i in batch: for i in batch:
ip = forgery_py.internet.ip_v4() address = forgery_py.internet.ip_v4()
hostname = forgery_py.email.address().replace('@', '.') hostname = forgery_py.email.address().replace('@', '.')
hostname = f'{hostname}-{ip}' hostname = f'{hostname}-{address}'
data = dict( data = dict(
ip=ip, address=address,
name=hostname, name=hostname,
platform_id=choice(self.platform_ids), platform_id=choice(self.platform_ids),
created_by='Fake', created_by='Fake',