mirror of https://github.com/jumpserver/jumpserver
perf: 优化 api sql 查询
parent
47c5f18c6e
commit
e7202ac984
|
@ -18,7 +18,5 @@
|
|||
arch: "{{ ansible_architecture }}"
|
||||
kernel: "{{ ansible_kernel }}"
|
||||
|
||||
|
||||
|
||||
- debug:
|
||||
var: info
|
||||
|
|
|
@ -87,7 +87,11 @@ class QuerySetMixin:
|
|||
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
if hasattr(self, 'action') and (self.action == 'list' or self.action == 'metadata'):
|
||||
if not hasattr(self, 'action'):
|
||||
return queryset
|
||||
if self.action == 'metadata':
|
||||
queryset = queryset.none()
|
||||
if self.action in ['list', 'metadata']:
|
||||
serializer_class = self.get_serializer_class()
|
||||
if serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
|
||||
queryset = serializer_class.setup_eager_loading(queryset)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from collections import Iterable, defaultdict
|
||||
from collections import Iterable, defaultdict, OrderedDict
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db.models import NOT_PROVIDED
|
||||
|
@ -8,8 +8,8 @@ from rest_framework.fields import SkipField, empty
|
|||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import html
|
||||
|
||||
from common.serializers.fields import EncryptedField
|
||||
from common.serializers.fields import LabeledChoiceField, ObjectRelatedField
|
||||
from common.db.fields import EncryptMixin
|
||||
from common.serializers.fields import EncryptedField, LabeledChoiceField, ObjectRelatedField
|
||||
|
||||
__all__ = [
|
||||
'BulkSerializerMixin', 'BulkListSerializerMixin',
|
||||
|
@ -268,6 +268,7 @@ class DefaultValueFieldsMixin:
|
|||
if not hasattr(self.Meta, 'model'):
|
||||
return
|
||||
model = self.Meta.model
|
||||
|
||||
for name, serializer_field in self.fields.items():
|
||||
if serializer_field.default != empty or serializer_field.required:
|
||||
continue
|
||||
|
@ -335,22 +336,38 @@ class SomeFieldsMixin:
|
|||
return value
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def order_fields(fields):
|
||||
bool_fields = []
|
||||
datetime_fields = []
|
||||
other_fields = []
|
||||
|
||||
for name, field in fields.items():
|
||||
to_add = (name, field)
|
||||
if isinstance(field, serializers.BooleanField):
|
||||
bool_fields.append(to_add)
|
||||
elif isinstance(field, serializers.DateTimeField):
|
||||
datetime_fields.append(to_add)
|
||||
else:
|
||||
other_fields.append(to_add)
|
||||
_fields = [*other_fields, *bool_fields, *datetime_fields]
|
||||
fields = OrderedDict()
|
||||
for name, field in _fields:
|
||||
fields[name] = field
|
||||
return fields
|
||||
|
||||
def get_fields(self):
|
||||
fields = super().get_fields()
|
||||
fields = self.order_fields(fields)
|
||||
secret_readable = isinstance(self, SecretReadableMixin)
|
||||
|
||||
for name, field in fields.items():
|
||||
if name == 'id':
|
||||
field.label = 'ID'
|
||||
elif name in self.secret_fields and \
|
||||
not isinstance(self, SecretReadableMixin):
|
||||
elif isinstance(field, EncryptMixin) and not secret_readable:
|
||||
field.write_only = True
|
||||
return fields
|
||||
|
||||
def get_field_names(self, declared_fields, info):
|
||||
names = super().get_field_names(declared_fields, info)
|
||||
common_names = [i for i in self.common_fields if i in names]
|
||||
primary_names = [i for i in names if i not in self.common_fields]
|
||||
return primary_names + common_names
|
||||
|
||||
|
||||
class CommonSerializerMixin(DynamicFieldsMixin, RelatedModelSerializerMixin,
|
||||
SomeFieldsMixin, DefaultValueFieldsMixin):
|
||||
|
|
|
@ -60,12 +60,6 @@ def on_request_finished_logging_db_query(sender, **kwargs):
|
|||
method = current_request.method
|
||||
path = current_request.get_full_path()
|
||||
|
||||
logger.debug(">>> [{}] {}".format(method, path))
|
||||
for name, counter in counters:
|
||||
logger.debug("Query {:3} times using {:.2f}s {}".format(
|
||||
counter.counter, counter.time, name)
|
||||
)
|
||||
|
||||
# print(">>> [{}] {}".format(method, path))
|
||||
# for table_name, queries in table_queries.items():
|
||||
# if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
|
||||
|
@ -77,6 +71,12 @@ def on_request_finished_logging_db_query(sender, **kwargs):
|
|||
# continue
|
||||
# print('\t{}. {}'.format(i, sql))
|
||||
|
||||
logger.debug(">>> [{}] {}".format(method, path))
|
||||
for name, counter in counters:
|
||||
logger.debug("Query {:3} times using {:.2f}s {}".format(
|
||||
counter.counter, counter.time, name)
|
||||
)
|
||||
|
||||
on_request_finished_release_local(sender, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -34,9 +34,6 @@ class OrgQuerySetMixin:
|
|||
% self.__class__.__name__
|
||||
)
|
||||
queryset = super().get_queryset()
|
||||
|
||||
if hasattr(self, 'swagger_fake_view'):
|
||||
return queryset.none()
|
||||
return queryset
|
||||
|
||||
|
||||
|
|
|
@ -133,23 +133,27 @@ class Organization(OrgRoleMixin, JMSBaseModel):
|
|||
def org_id(self):
|
||||
return self.id
|
||||
|
||||
@classmethod
|
||||
def get_or_create_builtin(cls, name, **kwargs):
|
||||
_id = kwargs.get('id')
|
||||
org = cls.get_instance(cls.DEFAULT_ID)
|
||||
if org:
|
||||
return org
|
||||
org, created = cls.objects.get_or_create(name=name, defaults=kwargs)
|
||||
if created:
|
||||
org.builtin = True
|
||||
org.save()
|
||||
return org
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
defaults = dict(id=cls.DEFAULT_ID, name=cls.DEFAULT_NAME)
|
||||
obj, created = cls.objects.get_or_create(defaults=defaults, id=cls.DEFAULT_ID)
|
||||
if not obj.builtin:
|
||||
obj.builtin = True
|
||||
obj.save()
|
||||
return obj
|
||||
kwargs = {'id': cls.DEFAULT_ID, 'name': cls.DEFAULT_NAME}
|
||||
return cls.get_or_create_builtin(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def system(cls):
|
||||
defaults = dict(id=cls.SYSTEM_ID, name=cls.SYSTEM_NAME)
|
||||
obj, created = cls.objects.get_or_create(defaults=defaults, id=cls.SYSTEM_ID)
|
||||
if not obj.builtin:
|
||||
obj.builtin = True
|
||||
obj.save()
|
||||
return obj
|
||||
kwargs = {'id': cls.SYSTEM_ID, 'name': cls.SYSTEM_NAME}
|
||||
return cls.get_or_create_builtin(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def root(cls):
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
from django.db.models import Count
|
||||
from django.db.models import Q, Count
|
||||
from django.utils.translation import ugettext as _
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
|
||||
from common.api import JMSModelViewSet
|
||||
from orgs.utils import current_org
|
||||
from .permission import PermissionViewSet
|
||||
from ..filters import RoleFilter
|
||||
from ..models import Role, SystemRole, OrgRole
|
||||
from ..models import Role, SystemRole, OrgRole, RoleBinding
|
||||
from ..serializers import RoleSerializer, RoleUserSerializer
|
||||
|
||||
__all__ = [
|
||||
|
@ -33,6 +34,7 @@ class RoleViewSet(JMSModelViewSet):
|
|||
if instance.builtin:
|
||||
error = _("Internal role, can't be destroy")
|
||||
raise PermissionDenied(error)
|
||||
|
||||
with tmp_to_root_org():
|
||||
if instance.users.count() >= 1:
|
||||
error = _("The role has been bound to users, can't be destroy")
|
||||
|
@ -54,6 +56,24 @@ class RoleViewSet(JMSModelViewSet):
|
|||
return
|
||||
instance.permissions.set(clone.get_permissions())
|
||||
|
||||
@staticmethod
|
||||
def set_users_amount(queryset):
|
||||
"""设置角色的用户绑定数量,以减少查询"""
|
||||
org_id = current_org.id
|
||||
q = Q(role__scope=Role.Scope.system) | Q(role__scope=Role.Scope.org, org_id=org_id)
|
||||
role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(user_count=Count('user_id'))
|
||||
role_user_amount_mapper = {role_id: user_count for role_id, user_count in role_bindings}
|
||||
queryset = queryset.annotate(permissions_amount=Count('permissions'))
|
||||
queryset = list(queryset)
|
||||
for role in queryset:
|
||||
role.users_amount = role_user_amount_mapper.get(role.id, 0)
|
||||
return queryset
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
queryset = super().filter_queryset(queryset)
|
||||
queryset = self.set_users_amount(queryset)
|
||||
return queryset
|
||||
|
||||
def perform_update(self, serializer):
|
||||
instance = serializer.instance
|
||||
if instance.builtin:
|
||||
|
@ -61,10 +81,6 @@ class RoleViewSet(JMSModelViewSet):
|
|||
raise PermissionDenied(error)
|
||||
return super().perform_update(serializer)
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset().annotate(permissions_amount=Count('permissions'))
|
||||
return queryset
|
||||
|
||||
@action(methods=['GET'], detail=True)
|
||||
def users(self, *args, **kwargs):
|
||||
role = self.get_object()
|
||||
|
@ -73,11 +89,13 @@ class RoleViewSet(JMSModelViewSet):
|
|||
|
||||
|
||||
class SystemRoleViewSet(RoleViewSet):
|
||||
queryset = SystemRole.objects.all()
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().filter(scope='system')
|
||||
|
||||
|
||||
class OrgRoleViewSet(RoleViewSet):
|
||||
queryset = OrgRole.objects.all()
|
||||
def get_queryset(self):
|
||||
return super().get_queryset().filter(scope='org')
|
||||
|
||||
|
||||
class BaseRolePermissionsViewSet(PermissionViewSet):
|
||||
|
|
|
@ -102,7 +102,7 @@ class Role(JMSBaseModel):
|
|||
|
||||
@lazyproperty
|
||||
def users_amount(self):
|
||||
return self.users.count()
|
||||
return 0
|
||||
|
||||
@lazyproperty
|
||||
def permissions_amount(self):
|
||||
|
|
|
@ -92,12 +92,17 @@ class RBACPermission(permissions.DjangoModelPermissions):
|
|||
|
||||
try:
|
||||
queryset = self._queryset(view)
|
||||
model_cls = queryset.model
|
||||
if isinstance(queryset, list) and queryset:
|
||||
model_cls = queryset[0].__class__
|
||||
else:
|
||||
model_cls = queryset.model
|
||||
except AssertionError:
|
||||
model_cls = None
|
||||
except AttributeError:
|
||||
model_cls = None
|
||||
except Exception as e:
|
||||
logger.error('Error get model class: {} of {}'.format(e, view))
|
||||
model_cls = None
|
||||
raise e
|
||||
return model_cls
|
||||
|
||||
def get_require_perms(self, request, view):
|
||||
|
|
|
@ -330,7 +330,7 @@ class RoleMixin:
|
|||
id: str
|
||||
_org_roles = None
|
||||
_system_roles = None
|
||||
PERM_CACHE_KEY = 'USER_PERMS_{}_{}'
|
||||
PERM_CACHE_KEY = 'USER_PERMS_ROLES_{}_{}'
|
||||
_is_superuser = None
|
||||
_update_superuser = False
|
||||
|
||||
|
@ -347,13 +347,36 @@ class RoleMixin:
|
|||
return SystemRoleManager(self)
|
||||
|
||||
@lazyproperty
|
||||
def perms(self):
|
||||
def console_orgs(self):
|
||||
return self.cached_role_and_perms['console_orgs']
|
||||
|
||||
@lazyproperty
|
||||
def audit_orgs(self):
|
||||
return self.cached_role_and_perms['audit_orgs']
|
||||
|
||||
@lazyproperty
|
||||
def workbench_orgs(self):
|
||||
return self.cached_role_and_perms['workbench_orgs']
|
||||
|
||||
@lazyproperty
|
||||
def cached_role_and_perms(self):
|
||||
from rbac.models import RoleBinding
|
||||
|
||||
key = self.PERM_CACHE_KEY.format(self.id, current_org.id)
|
||||
perms = cache.get(key)
|
||||
if not perms or settings.DEBUG:
|
||||
perms = self.get_all_permissions()
|
||||
cache.set(key, perms, 3600)
|
||||
return perms
|
||||
data = cache.get(key)
|
||||
if data:
|
||||
return data
|
||||
|
||||
data = {
|
||||
'console_orgs': RoleBinding.get_user_has_the_perm_orgs('rbac.view_console', self),
|
||||
'audit_orgs': RoleBinding.get_user_has_the_perm_orgs('rbac.view_audit', self),
|
||||
'workbench_orgs': RoleBinding.get_user_has_the_perm_orgs('rbac.view_workbench', self),
|
||||
'org_roles': self.org_roles.all(),
|
||||
'system_roles': self.system_roles.all(),
|
||||
'perms': self.get_all_permissions(),
|
||||
}
|
||||
cache.set(key, data, 60 * 60)
|
||||
return data
|
||||
|
||||
def expire_rbac_perms_cache(self):
|
||||
key = self.PERM_CACHE_KEY.format(self.id, '*')
|
||||
|
@ -364,6 +387,10 @@ class RoleMixin:
|
|||
key = cls.PERM_CACHE_KEY.format('*', '*')
|
||||
cache.delete_pattern(key)
|
||||
|
||||
@lazyproperty
|
||||
def perms(self):
|
||||
return self.cached_role_and_perms['perms']
|
||||
|
||||
@property
|
||||
def is_superuser(self):
|
||||
"""
|
||||
|
@ -746,18 +773,6 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
|
|||
def receive_backends(self):
|
||||
return self.user_msg_subscription.receive_backends
|
||||
|
||||
@property
|
||||
def is_wecom_bound(self):
|
||||
return bool(self.wecom_id)
|
||||
|
||||
@property
|
||||
def is_dingtalk_bound(self):
|
||||
return bool(self.dingtalk_id)
|
||||
|
||||
@property
|
||||
def is_feishu_bound(self):
|
||||
return bool(self.feishu_id)
|
||||
|
||||
@property
|
||||
def is_otp_secret_key_bound(self):
|
||||
return bool(self.otp_secret_key)
|
||||
|
@ -765,10 +780,6 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
|
|||
def get_absolute_url(self):
|
||||
return reverse('users:user-detail', args=(self.id,))
|
||||
|
||||
@property
|
||||
def groups_display(self):
|
||||
return ' '.join([group.name for group in self.groups.all()])
|
||||
|
||||
@property
|
||||
def source_display(self):
|
||||
return self.get_source_display()
|
||||
|
@ -808,7 +819,7 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
|
|||
oauth2 = self.Source.oauth2
|
||||
return self.source not in [cas, saml2, oauth2]
|
||||
|
||||
def set_unprovide_attr_if_need(self):
|
||||
def set_required_attr_if_need(self):
|
||||
if not self.name:
|
||||
self.name = self.username
|
||||
if not self.email or '@' not in self.email:
|
||||
|
@ -818,7 +829,7 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
|
|||
self.email = email
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.set_unprovide_attr_if_need()
|
||||
self.set_required_attr_if_need()
|
||||
if self.username == 'admin':
|
||||
self.role = 'Admin'
|
||||
self.is_active = True
|
||||
|
@ -880,21 +891,6 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
|
|||
return None
|
||||
return self.SOURCE_BACKEND_MAPPING.get(self.source, [])
|
||||
|
||||
@lazyproperty
|
||||
def console_orgs(self):
|
||||
from rbac.models import RoleBinding
|
||||
return RoleBinding.get_user_has_the_perm_orgs('rbac.view_console', self)
|
||||
|
||||
@lazyproperty
|
||||
def audit_orgs(self):
|
||||
from rbac.models import RoleBinding
|
||||
return RoleBinding.get_user_has_the_perm_orgs('rbac.view_audit', self)
|
||||
|
||||
@lazyproperty
|
||||
def workbench_orgs(self):
|
||||
from rbac.models import RoleBinding
|
||||
return RoleBinding.get_user_has_the_perm_orgs('rbac.view_workbench', self)
|
||||
|
||||
class Meta:
|
||||
ordering = ['username']
|
||||
verbose_name = _("User")
|
||||
|
|
|
@ -2,11 +2,10 @@ from django.conf import settings
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.utils import validate_ssh_public_key
|
||||
from common.serializers.fields import EncryptedField, LabeledChoiceField
|
||||
from ..models import User
|
||||
|
||||
from common.utils import validate_ssh_public_key
|
||||
from .user import UserSerializer
|
||||
from ..models import User
|
||||
|
||||
|
||||
class UserOrgSerializer(serializers.Serializer):
|
||||
|
@ -116,7 +115,6 @@ class UserProfileSerializer(UserSerializer):
|
|||
(0, _('Disable')),
|
||||
(1, _('Enable')),
|
||||
)
|
||||
|
||||
public_key_comment = serializers.CharField(
|
||||
source='get_public_key_comment', required=False, read_only=True, max_length=128
|
||||
)
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import django
|
||||
import argparse
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
APPS_DIR = os.path.join(BASE_DIR, 'apps')
|
||||
|
@ -17,7 +18,6 @@ from resources.users import UserGroupGenerator, UserGenerator
|
|||
from resources.perms import AssetPermissionGenerator
|
||||
from resources.terminal import CommandGenerator, SessionGenerator
|
||||
|
||||
|
||||
resource_generator_mapper = {
|
||||
'asset': AssetsGenerator,
|
||||
'platform': PlatformGenerator,
|
||||
|
@ -26,7 +26,8 @@ resource_generator_mapper = {
|
|||
'user_group': UserGroupGenerator,
|
||||
'asset_permission': AssetPermissionGenerator,
|
||||
'command': CommandGenerator,
|
||||
'session': SessionGenerator
|
||||
'session': SessionGenerator,
|
||||
'all': None
|
||||
# 'stat': StatGenerator
|
||||
}
|
||||
|
||||
|
@ -36,16 +37,24 @@ def main():
|
|||
parser.add_argument(
|
||||
'resource', type=str,
|
||||
choices=resource_generator_mapper.keys(),
|
||||
default='all',
|
||||
help="resource to generate"
|
||||
)
|
||||
parser.add_argument('-c', '--count', type=int, default=100)
|
||||
parser.add_argument('-c', '--count', type=int, default=10000)
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=100)
|
||||
parser.add_argument('-o', '--org', type=str, default='')
|
||||
args = parser.parse_args()
|
||||
resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org
|
||||
generator_cls = resource_generator_mapper[resource]
|
||||
generator = generator_cls(org_id=org_id, batch_size=batch_size)
|
||||
generator.generate(count)
|
||||
|
||||
generator_cls = []
|
||||
if resource == 'all':
|
||||
generator_cls = resource_generator_mapper.values()
|
||||
else:
|
||||
generator_cls.push(resource_generator_mapper[resource])
|
||||
|
||||
for _cls in generator_cls:
|
||||
generator = _cls(org_id=org_id, batch_size=batch_size)
|
||||
generator.generate(count)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from random import choice
|
||||
import random
|
||||
from random import choice
|
||||
|
||||
import forgery_py
|
||||
|
||||
from .base import FakeDataGenerator
|
||||
|
||||
from assets.models import *
|
||||
from assets.const import AllTypes
|
||||
from assets.models import *
|
||||
from .base import FakeDataGenerator
|
||||
|
||||
|
||||
class NodesGenerator(FakeDataGenerator):
|
||||
|
@ -59,11 +59,11 @@ class AssetsGenerator(FakeDataGenerator):
|
|||
assets = []
|
||||
|
||||
for i in batch:
|
||||
ip = forgery_py.internet.ip_v4()
|
||||
address = forgery_py.internet.ip_v4()
|
||||
hostname = forgery_py.email.address().replace('@', '.')
|
||||
hostname = f'{hostname}-{ip}'
|
||||
hostname = f'{hostname}-{address}'
|
||||
data = dict(
|
||||
ip=ip,
|
||||
address=address,
|
||||
name=hostname,
|
||||
platform_id=choice(self.platform_ids),
|
||||
created_by='Fake',
|
||||
|
|
Loading…
Reference in New Issue