mirror of https://github.com/jumpserver/jumpserver
feat: Add cache module and organization resource statistics (#5407)
* feat: Add cache module and organization resource statistics * refactor * recover .gitkeep * refactor * Merge signal handling * Fix missing signal when adding users to an organization * Changed a log level Co-authored-by: xinwen <coderWen@126.com>
parent f04e2fa090
commit a7fa2331bd
@@ -0,0 +1,187 @@
import json
from django.core.cache import cache

from common.utils.lock import DistributedLock
from common.utils import lazyproperty
from common.utils import get_logger

logger = get_logger(__file__)


class CacheFieldBase:
    field_type = str

    def __init__(self, queryset=None, compute_func_name=None):
        assert None in (queryset, compute_func_name), f'queryset and compute_func_name can only have one'
        self.compute_func_name = compute_func_name
        self.queryset = queryset


class CharField(CacheFieldBase):
    field_type = str


class IntegerField(CacheFieldBase):
    field_type = int


class CacheBase(type):
    def __new__(cls, name, bases, attrs: dict):
        to_update = {}
        field_desc_mapper = {}

        for k, v in attrs.items():
            if isinstance(v, CacheFieldBase):
                desc = CacheValueDesc(k, v)
                to_update[k] = desc
                field_desc_mapper[k] = desc

        attrs.update(to_update)
        attrs['field_desc_mapper'] = field_desc_mapper
        return type.__new__(cls, name, bases, attrs)


class Cache(metaclass=CacheBase):
    field_desc_mapper: dict
    timeout = None

    def __init__(self):
        self._data = None

    @lazyproperty
    def key_suffix(self):
        return self.get_key_suffix()

    @property
    def key_prefix(self):
        clz = self.__class__
        return f'cache.{clz.__module__}.{clz.__name__}'

    @property
    def key(self):
        return f'{self.key_prefix}.{self.key_suffix}'

    @property
    def data(self):
        if self._data is None:
            data = self.get_data()
            if data is None:
                # No data in the cache, fetch it from the database
                self.compute_and_set_all_data()
        return self._data

    def get_data(self) -> dict:
        data = cache.get(self.key)
        logger.debug(f'CACHE: get {self.key} = {data}')
        if data is not None:
            data = json.loads(data)
            self._data = data
        return data

    def set_data(self, data):
        self._data = data
        to_json = json.dumps(data)
        logger.info(f'CACHE: set {self.key} = {to_json}, timeout={self.timeout}')
        cache.set(self.key, to_json, timeout=self.timeout)

    def _compute_data(self, *fields):
        field_descs = []
        if not fields:
            field_descs = self.field_desc_mapper.values()
        else:
            for field in fields:
                assert field in self.field_desc_mapper, f'{field} is not a valid field'
                field_descs.append(self.field_desc_mapper[field])
        data = {
            field_desc.field_name: field_desc.compute_value(self)
            for field_desc in field_descs
        }
        return data

    def compute_and_set_all_data(self, computed_data: dict = None):
        """
        TODO: how to prevent concurrent recomputation of all data, which wastes database resources
        """
        uncomputed_keys = ()
        if computed_data:
            computed_keys = computed_data.keys()
            all_keys = self.field_desc_mapper.keys()
            uncomputed_keys = all_keys - computed_keys
        else:
            computed_data = {}
        data = self._compute_data(*uncomputed_keys)
        data.update(computed_data)
        self.set_data(data)
        return data

    def refresh_part_data_with_lock(self, refresh_data):
        with DistributedLock(name=f'{self.key}.refresh'):
            data = self.get_data()
            if data is not None:
                data.update(refresh_data)
                self.set_data(data)
                return data

    def refresh(self, *fields):
        if not fields:
            # No fields specified, refresh all values
            self.compute_and_set_all_data()
            return

        data = self.get_data()
        if data is None:
            # No data in the cache, set all values
            self.compute_and_set_all_data()
            return

        refresh_data = self._compute_data(*fields)
        if not self.refresh_part_data_with_lock(refresh_data):
            # Partial refresh failed because the cache is empty, update all values
            self.compute_and_set_all_data(refresh_data)
            return

    def get_key_suffix(self):
        raise NotImplementedError

    def reload(self):
        self._data = None

    def delete(self):
        self._data = None
        logger.info(f'CACHE: delete {self.key}')
        cache.delete(self.key)


class CacheValueDesc:
    def __init__(self, field_name, field_type: CacheFieldBase):
        self.field_name = field_name
        self.field_type = field_type
        self._data = None

    def __repr__(self):
        clz = self.__class__
        return f'<{clz.__name__} {self.field_name} {self.field_type}>'

    def __get__(self, instance: Cache, owner):
        if instance is None:
            return self
        if self.field_name not in instance.data:
            instance.refresh(self.field_name)
        value = instance.data[self.field_name]
        return value

    def compute_value(self, instance: Cache):
        if self.field_type.queryset is not None:
            new_value = self.field_type.queryset.count()
        else:
            compute_func_name = self.field_type.compute_func_name
            if not compute_func_name:
                compute_func_name = f'compute_{self.field_name}'
            compute_func = getattr(instance, compute_func_name, None)
            assert compute_func is not None, \
                f'Define `{compute_func_name}` method in {instance.__class__}'
            new_value = compute_func()

        new_value = self.field_type.field_type(new_value)
        logger.info(f'CACHE: compute {instance.key}.{self.field_name} = {new_value}')
        return new_value
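For orientation, a minimal usage sketch of the module above (hypothetical DemoCache, not part of this commit): fields declared on a Cache subclass are turned into CacheValueDesc descriptors by the CacheBase metaclass, and reading any of them computes every field and stores the whole dict as JSON under one per-class cache key.

# Hypothetical example, not part of this commit.
from users.models import User
from common.cache import Cache, IntegerField


class DemoCache(Cache):
    timeout = 3600                                       # cached JSON expires after an hour
    users_amount = IntegerField(queryset=User.objects)   # counted with queryset.count()
    other_amount = IntegerField()                        # computed by compute_other_amount()

    def get_key_suffix(self):
        return 'demo'                                    # full key: cache.<module>.DemoCache.demo

    def compute_other_amount(self):
        return 42


demo = DemoCache()
print(demo.users_amount)      # first access computes all fields and writes them to the cache backend
demo.refresh('other_amount')  # recompute a single field under a distributed lock
demo.delete()                 # drop the cached JSON blob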
@@ -12,3 +12,6 @@ PRE_REMOVE = 'pre_remove'
POST_REMOVE = 'post_remove'
PRE_CLEAR = 'pre_clear'
POST_CLEAR = 'post_clear'

POST_PREFIX = 'post'
PRE_PREFIX = 'pre'
@@ -124,6 +124,22 @@ class BulkListSerializerMixin(object):

        return ret

    def create(self, validated_data):
        ModelClass = self.child.Meta.model
        use_model_bulk_create = getattr(self.child.Meta, 'use_model_bulk_create', False)
        model_bulk_create_kwargs = getattr(self.child.Meta, 'model_bulk_create_kwargs', {})

        if use_model_bulk_create:
            to_create = [
                ModelClass(**attrs) for attrs in validated_data
            ]
            objs = ModelClass._default_manager.bulk_create(
                to_create, **model_bulk_create_kwargs
            )
            return objs
        else:
            return super().create(validated_data)


class BaseDynamicFieldsPlugin:
    def __init__(self, serializer):
@@ -4,7 +4,6 @@ from celery import shared_task

from .utils import get_logger


logger = get_logger(__file__)
@@ -1,4 +1,5 @@
from functools import wraps
import threading

from redis_lock import Lock as RedisLock
from redis import Redis
@@ -35,11 +36,16 @@ class DistributedLock(RedisLock):
        self._blocking = blocking

    def __enter__(self):
        thread_id = threading.current_thread().ident
        logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> attempt to acquire <lock:{self._name}> ...')
        acquired = self.acquire(blocking=self._blocking)
        if self._blocking and not acquired:
            logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> was not acquired <lock:{self._name}>, but blocking=True')
            raise EnvironmentError("Lock wasn't acquired, but blocking=True")
        if not acquired:
            logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> acquire <lock:{self._name}> failed')
            raise AcquireFailed
        logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> acquire <lock:{self._name}> ok')
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
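The lock is consumed as a context manager elsewhere in this commit (Cache.refresh_part_data_with_lock above). A minimal sketch, assuming AcquireFailed is defined in the same common.utils.lock module and that the constructor's blocking flag (set in __init__, not shown in this hunk) is left non-blocking:

# Hypothetical example, not part of this commit.
from common.utils.lock import DistributedLock, AcquireFailed

try:
    # __enter__ raises AcquireFailed when the lock is already held and blocking is False
    with DistributedLock(name='demo.refresh'):
        pass  # critical section: only one worker runs this at a time
except AcquireFailed:
    pass  # another worker holds the lock; skip or retry later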
@@ -75,11 +75,6 @@ class OrgMemberRelationBulkViewSet(JMSBulkRelationModelViewSet):
    filterset_class = OrgMemberRelationFilterSet
    search_fields = ('user__name', 'user__username', 'org__name')

    def perform_bulk_create(self, serializer):
        data = serializer.validated_data
        relations = [OrganizationMember(**i) for i in data]
        OrganizationMember.objects.bulk_create(relations, ignore_conflicts=True)

    def perform_bulk_destroy(self, queryset):
        objs = list(queryset.all().prefetch_related('user', 'org'))
        queryset.delete()
@@ -0,0 +1,34 @@
from django.db.transaction import on_commit

from common.cache import *
from .utils import current_org, tmp_to_org
from .tasks import refresh_org_cache_task
from orgs.models import Organization


class OrgRelatedCache(Cache):

    def __init__(self):
        super().__init__()
        self.current_org = Organization.get_instance(current_org.id)

    def get_current_org(self):
        """
        Callback exposed to subclasses for controlling the organization context
        1. Allows controlling the organization in an interactive environment
        2. Allows controlling the organization inside a celery task
        """
        return self.current_org

    def refresh(self, *fields):
        with tmp_to_org(self.get_current_org()):
            return super().refresh(*fields)

    def refresh_async(self, *fields):
        """
        Send the signal only after the transaction commits, so that transaction
        isolation does not cause stale data to be read
        """
        def func():
            logger.info(f'CACHE: Send refresh task {self}.{fields}')
            refresh_org_cache_task.delay(self, *fields)
        on_commit(func)
@@ -0,0 +1,46 @@
from .cache import OrgRelatedCache, IntegerField
from users.models import UserGroup, User
from assets.models import Node, AdminUser, SystemUser, Domain, Gateway
from applications.models import Application
from perms.models import AssetPermission, ApplicationPermission
from .models import OrganizationMember


class OrgResourceStatisticsCache(OrgRelatedCache):
    users_amount = IntegerField()
    groups_amount = IntegerField(queryset=UserGroup.objects)

    assets_amount = IntegerField()
    nodes_amount = IntegerField(queryset=Node.objects)
    admin_users_amount = IntegerField(queryset=AdminUser.objects)
    system_users_amount = IntegerField(queryset=SystemUser.objects)
    domains_amount = IntegerField(queryset=Domain.objects)
    gateways_amount = IntegerField(queryset=Gateway.objects)

    applications_amount = IntegerField(queryset=Application.objects)

    asset_perms_amount = IntegerField(queryset=AssetPermission.objects)
    app_perms_amount = IntegerField(queryset=ApplicationPermission.objects)

    def __init__(self, org):
        super().__init__()
        self.org = org

    def get_key_suffix(self):
        return f'<org:{self.org.id}>'

    def get_current_org(self):
        return self.org

    def compute_users_amount(self):
        if self.org.is_real():
            users_amount = OrganizationMember.objects.values(
                'user_id'
            ).filter(org_id=self.org.id).distinct().count()
        else:
            users_amount = User.objects.all().distinct().count()
        return users_amount

    def compute_assets_amount(self):
        node = Node.org_root()
        return node.assets_amount
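A brief usage sketch (not part of this commit): accessing any statistics attribute computes and caches all counters for the organization under a single per-org key, and individual fields can be refreshed synchronously or through the celery task.

# Hypothetical example, not part of this commit.
from orgs.models import Organization
from orgs.caches import OrgResourceStatisticsCache

org = Organization.objects.first()
stats = OrgResourceStatisticsCache(org)
print(stats.users_amount, stats.assets_amount)  # descriptor access computes and caches every counter
stats.refresh('nodes_amount')                   # recompute one field inside tmp_to_org(org)
stats.refresh_async('users_amount')             # enqueue refresh_org_cache_task after the transaction commits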
@@ -7,7 +7,7 @@ from django.db.models import signals
from django.db.models import Q
from django.utils.translation import ugettext_lazy as _

from common.utils import is_uuid
from common.utils import is_uuid, lazyproperty
from common.const import choices
from common.db.models import ChoiceSet
@@ -215,6 +215,33 @@ class Organization(models.Model):
        from .utils import set_current_org
        set_current_org(self)

    @lazyproperty
    def resource_statistics_cache(self):
        from .caches import OrgResourceStatisticsCache
        return OrgResourceStatisticsCache(self)

    def get_total_resources_amount(self):
        from django.apps import apps
        from orgs.mixins.models import OrgModelMixin
        summary = {'users.Members': self.members.all().count()}
        for app_name, app_config in apps.app_configs.items():
            models_cls = app_config.get_models()
            for model in models_cls:
                if not issubclass(model, OrgModelMixin):
                    continue
                key = '{}.{}'.format(app_name, model.__name__)
                summary[key] = self.get_resource_amount(model)
        return summary

    def get_resource_amount(self, resource_model):
        from .utils import tmp_to_org
        from .mixins.models import OrgModelMixin

        if not issubclass(resource_model, OrgModelMixin):
            return 0
        with tmp_to_org(self):
            return resource_model.objects.all().count()


def _convert_to_uuid_set(users):
    rst = set()
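A short sketch of how the new Organization helpers are used (hypothetical shell session, not part of this commit):

# Hypothetical example, not part of this commit.
from orgs.models import Organization

org = Organization.objects.first()
org.resource_statistics_cache.users_amount  # cached counter backed by OrgResourceStatisticsCache
org.get_total_resources_amount()            # live per-model counts plus 'users.Members', keyed by '<app>.<Model>'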
@@ -10,18 +10,37 @@ from common.db.models import concated_display as display
from .models import Organization, OrganizationMember, ROLE


class ResourceStatisticsSerializer(serializers.Serializer):
    users_amount = serializers.IntegerField(required=False)
    groups_amount = serializers.IntegerField(required=False)

    assets_amount = serializers.IntegerField(required=False)
    nodes_amount = serializers.IntegerField(required=False)
    admin_users_amount = serializers.IntegerField(required=False)
    system_users_amount = serializers.IntegerField(required=False)
    domains_amount = serializers.IntegerField(required=False)
    gateways_amount = serializers.IntegerField(required=False)

    applications_amount = serializers.IntegerField(required=False)
    asset_perms_amount = serializers.IntegerField(required=False)
    app_perms_amount = serializers.IntegerField(required=False)


class OrgSerializer(ModelSerializer):
    users = serializers.PrimaryKeyRelatedField(many=True, queryset=User.objects.all(), write_only=True, required=False)
    admins = serializers.PrimaryKeyRelatedField(many=True, queryset=User.objects.all(), write_only=True, required=False)
    auditors = serializers.PrimaryKeyRelatedField(many=True, queryset=User.objects.all(), write_only=True, required=False)

    resource_statistics = ResourceStatisticsSerializer(source='resource_statistics_cache')

    class Meta:
        model = Organization
        list_serializer_class = AdaptedBulkListSerializer
        fields_mini = ['id', 'name']
        fields_small = fields_mini + [
            'created_by', 'date_created', 'comment'
            'created_by', 'date_created', 'comment', 'resource_statistics'
        ]

        fields_m2m = ['users', 'admins', 'auditors']
        fields = fields_small + fields_m2m
        read_only_fields = ['created_by', 'date_created']
@@ -60,6 +79,8 @@ class OrgMemberSerializer(BulkModelSerializer):
    class Meta:
        model = OrganizationMember
        fields = ('id', 'org', 'user', 'role', 'org_display', 'user_display', 'role_display')
        use_model_bulk_create = True
        model_bulk_create_kwargs = {'ignore_conflicts': True}

    def get_unique_together_validators(self):
        if self.parent:
@@ -4,7 +4,7 @@ from collections import defaultdict
from functools import partial

from django.db.models.signals import m2m_changed
from django.db.models.signals import post_save
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver

from orgs.utils import tmp_to_org
@@ -12,7 +12,10 @@ from .models import Organization, OrganizationMember
from .hands import set_current_org, Node, get_current_org
from perms.models import (AssetPermission, ApplicationPermission)
from users.models import UserGroup, User
from common.const.signals import PRE_REMOVE, POST_REMOVE
from applications.models import Application
from assets.models import Asset, AdminUser, SystemUser, Domain, Gateway
from common.const.signals import PRE_REMOVE, POST_REMOVE, POST_PREFIX
from .caches import OrgResourceStatisticsCache


@receiver(post_save, sender=Organization)
@@ -106,3 +109,72 @@ def on_org_user_changed(action, instance, reverse, pk_set, **kwargs):

    leaved_users = set(pk_set) - set(org.members.filter(id__in=user_pk_set).values_list('id', flat=True))
    _clear_users_from_org(org, leaved_users)


# Cache related
# -----------------------------------------------------

def refresh_user_amount_on_user_create_or_delete(user_id):
    orgs = Organization.objects.filter(m2m_org_members__user_id=user_id).distinct()
    for org in orgs:
        org_cache = OrgResourceStatisticsCache(org)
        org_cache.refresh_async('users_amount')


@receiver(post_save, sender=User)
def on_user_create(sender, instance, created, **kwargs):
    if created:
        refresh_user_amount_on_user_create_or_delete(instance.id)


@receiver(pre_delete, sender=User)
def on_user_delete(sender, instance, **kwargs):
    refresh_user_amount_on_user_create_or_delete(instance.id)


@receiver(m2m_changed, sender=OrganizationMember)
def on_org_user_changed(sender, action, instance, reverse, pk_set, **kwargs):
    if not action.startswith(POST_PREFIX):
        return

    if reverse:
        orgs = Organization.objects.filter(id__in=pk_set)
    else:
        orgs = [instance]

    for org in orgs:
        org_cache = OrgResourceStatisticsCache(org)
        org_cache.refresh_async('users_amount')


class OrgResourceStatisticsRefreshUtil:
    model_cache_field_mapper = {
        ApplicationPermission: 'app_perms_amount',
        AssetPermission: 'asset_perms_amount',
        Application: 'applications_amount',
        Gateway: 'gateways_amount',
        Domain: 'domains_amount',
        SystemUser: 'system_users_amount',
        AdminUser: 'admin_users_amount',
        Node: 'nodes_amount',
        Asset: 'assets_amount',
        UserGroup: 'groups_amount',
    }

    @classmethod
    def refresh_if_need(cls, instance):
        cache_field_name = cls.model_cache_field_mapper.get(type(instance))
        if cache_field_name:
            org_cache = OrgResourceStatisticsCache(instance.org)
            org_cache.refresh_async(cache_field_name)


@receiver(post_save)
def on_post_save_refresh_org_resource_statistics_cache(sender, instance, created, **kwargs):
    if created:
        OrgResourceStatisticsRefreshUtil.refresh_if_need(instance)


@receiver(pre_delete)
def on_pre_delete_refresh_org_resource_statistics_cache(sender, instance, **kwargs):
    OrgResourceStatisticsRefreshUtil.refresh_if_need(instance)
@@ -0,0 +1,11 @@
from celery import shared_task

from common.utils import get_logger

logger = get_logger(__file__)


@shared_task
def refresh_org_cache_task(cache, *fields):
    logger.info(f'CACHE: refresh <org: {cache.get_current_org()}> {cache.key}.{fields}')
    cache.refresh(*fields)
@@ -18,7 +18,7 @@ from django.shortcuts import reverse

from common.local import LOCAL_DYNAMIC_SETTINGS
from orgs.utils import current_org
from orgs.models import OrganizationMember
from orgs.models import OrganizationMember, Organization
from common.utils import date_expired_default, get_logger, lazyproperty
from common import fields
from common.const import choices
@@ -327,7 +327,8 @@ class RoleMixin:
    def remove(self):
        if not current_org.is_real():
            return
        OrganizationMember.objects.remove_users(current_org, [self])
        org = Organization.get_instance(current_org.id)
        OrganizationMember.objects.remove_users(org, [self])

    @classmethod
    def get_super_admins(cls):