mirror of https://github.com/jumpserver/jumpserver
添加 UnionQuertSet (#5578)
* 添加 UnionQuertSet * 跑通了 * 改变了 count 这类方法的代理模式 * 使用了老广的 Co-authored-by: xinwen <coderWen@126.com>pull/5605/head
parent
50e6c96358
commit
501ad698b7
|
@ -273,3 +273,7 @@ class Time:
|
|||
for timestamp, msg in zip(timestamps, self._msgs):
|
||||
logger.debug(f'TIME_IT: {msg} {timestamp-last}')
|
||||
last = timestamp
|
||||
|
||||
|
||||
def isinstance_method(attr):
|
||||
return isinstance(attr, type(Time().time))
|
||||
|
|
|
@ -34,12 +34,12 @@ class UserAllGrantedAssetsQuerysetMixin:
|
|||
pagination_class = AllGrantedAssetPagination
|
||||
user: User
|
||||
|
||||
def get_union_queryset(self, qs_stage: QuerySetStage):
|
||||
def get_queryset(self):
|
||||
if getattr(self, 'swagger_fake_view', False):
|
||||
return Asset.objects.none()
|
||||
qs_stage.prefetch_related('platform').only(*self.only_fields)
|
||||
queryset = UserGrantedAssetsQueryUtils(self.user) \
|
||||
.get_all_granted_assets(qs_stage)
|
||||
.get_all_granted_assets()
|
||||
queryset = queryset.prefetch_related('platform').only(*self.only_fields)
|
||||
return queryset
|
||||
|
||||
|
||||
|
@ -47,13 +47,13 @@ class UserFavoriteGrantedAssetsMixin:
|
|||
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
|
||||
user: User
|
||||
|
||||
def get_union_queryset(self, qs_stage: QuerySetStage):
|
||||
def get_queryset(self):
|
||||
if getattr(self, 'swagger_fake_view', False):
|
||||
return Asset.objects.none()
|
||||
user = self.user
|
||||
qs_stage.prefetch_related('platform').only(*self.only_fields)
|
||||
utils = UserGrantedAssetsQueryUtils(user)
|
||||
assets = utils.get_favorite_assets(qs_stage=qs_stage)
|
||||
assets = utils.get_favorite_assets()
|
||||
assets = assets.prefetch_related('platform').only(*self.only_fields)
|
||||
return assets
|
||||
|
||||
|
||||
|
@ -63,58 +63,35 @@ class UserGrantedNodeAssetsMixin:
|
|||
pagination_node: Node
|
||||
user: User
|
||||
|
||||
def get_union_queryset(self, qs_stage: QuerySetStage):
|
||||
def get_queryset(self):
|
||||
if getattr(self, 'swagger_fake_view', False):
|
||||
return Asset.objects.none()
|
||||
node_id = self.kwargs.get("node_id")
|
||||
qs_stage.prefetch_related('platform').only(*self.only_fields)
|
||||
|
||||
node, assets = UserGrantedAssetsQueryUtils(self.user).get_node_all_assets(
|
||||
node_id, qs_stage=qs_stage
|
||||
node_id
|
||||
)
|
||||
assets = assets.prefetch_related('platform').only(*self.only_fields)
|
||||
self.pagination_node = node
|
||||
return assets
|
||||
|
||||
|
||||
# 控制格式的 ----------------------------------------------------
|
||||
|
||||
class AssetsUnionQuerysetMixin:
|
||||
def get_queryset_union_prefer(self):
|
||||
if hasattr(self, 'get_union_queryset'):
|
||||
# 为了支持 union 查询
|
||||
queryset = Asset.objects.all().distinct()
|
||||
queryset = self.filter_queryset(queryset)
|
||||
qs_stage = QuerySetStage()
|
||||
qs_stage.and_with_queryset(queryset)
|
||||
queryset = self.get_union_queryset(qs_stage)
|
||||
else:
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
return queryset
|
||||
|
||||
|
||||
class AssetsSerializerFormatMixin(AssetsUnionQuerysetMixin):
|
||||
class AssetsSerializerFormatMixin:
|
||||
serializer_class = serializers.AssetGrantedSerializer
|
||||
filterset_fields = ['hostname', 'ip', 'id', 'comment']
|
||||
search_fields = ['hostname', 'ip', 'comment']
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
queryset = self.get_queryset_union_prefer()
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class AssetsTreeFormatMixin(AssetsUnionQuerysetMixin, SerializeToTreeNodeMixin):
|
||||
class AssetsTreeFormatMixin(SerializeToTreeNodeMixin):
|
||||
"""
|
||||
将 资产 序列化成树的结构返回
|
||||
"""
|
||||
|
||||
def list(self, request: Request, *args, **kwargs):
|
||||
queryset = self.get_queryset_union_prefer()
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
if request.query_params.get('search'):
|
||||
# 如果用户搜索的条件不精准,会导致返回大量的无意义数据。
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from collections import defaultdict
|
||||
from typing import List, Tuple
|
||||
from functools import reduce, partial
|
||||
from common.utils import isinstance_method
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.conf import settings
|
||||
|
@ -51,6 +53,81 @@ def get_user_all_asset_perm_ids(user) -> set:
|
|||
return asset_perm_ids
|
||||
|
||||
|
||||
class UnionQuerySet(QuerySet):
|
||||
after_union = ['order_by']
|
||||
not_return_qs = [
|
||||
'query', 'get', 'create', 'get_or_create',
|
||||
'update_or_create', 'bulk_create', 'count',
|
||||
'latest', 'earliest', 'first', 'last', 'aggregate',
|
||||
'exists', 'update', 'delete', 'as_manager', 'explain',
|
||||
]
|
||||
|
||||
def __init__(self, *queryset_list):
|
||||
self.queryset_list = queryset_list
|
||||
self.after_union_items = []
|
||||
self.before_union_items = []
|
||||
|
||||
def __execute(self):
|
||||
queryset_list = []
|
||||
for qs in self.queryset_list:
|
||||
for attr, args, kwargs in self.before_union_items:
|
||||
qs = getattr(qs, attr)(*args, **kwargs)
|
||||
queryset_list.append(qs)
|
||||
union_qs = reduce(lambda x, y: x.union(y), queryset_list)
|
||||
for attr, args, kwargs in self.after_union_items:
|
||||
union_qs = getattr(union_qs, attr)(*args, **kwargs)
|
||||
return union_qs
|
||||
|
||||
def __before_union_perform(self, item, *args, **kwargs):
|
||||
self.before_union_items.append((item, args, kwargs))
|
||||
return self.__clone(*self.queryset_list)
|
||||
|
||||
def __after_union_perform(self, item, *args, **kwargs):
|
||||
self.after_union_items.append((item, args, kwargs))
|
||||
return self.__clone(*self.queryset_list)
|
||||
|
||||
def __clone(self, *queryset_list):
|
||||
uqs = UnionQuerySet(*queryset_list)
|
||||
uqs.after_union_items = self.after_union_items
|
||||
uqs.before_union_items = self.before_union_items
|
||||
return uqs
|
||||
|
||||
def __getattribute__(self, item):
|
||||
if item.startswith('__') or item in UnionQuerySet.__dict__ or item in [
|
||||
'queryset_list', 'after_union_items', 'before_union_items'
|
||||
]:
|
||||
return object.__getattribute__(self, item)
|
||||
|
||||
if item in UnionQuerySet.not_return_qs:
|
||||
return getattr(self.__execute(), item)
|
||||
|
||||
origin_item = object.__getattribute__(self, 'queryset_list')[0]
|
||||
origin_attr = getattr(origin_item, item, None)
|
||||
if not isinstance_method(origin_attr):
|
||||
return getattr(self.__execute(), item)
|
||||
|
||||
if item in UnionQuerySet.after_union:
|
||||
attr = partial(self.__after_union_perform, item)
|
||||
else:
|
||||
attr = partial(self.__before_union_perform, item)
|
||||
return attr
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.__execute()[item]
|
||||
|
||||
def __next__(self):
|
||||
return next(self.__execute())
|
||||
|
||||
@classmethod
|
||||
def test_it(cls):
|
||||
from assets.models import Asset
|
||||
assets1 = Asset.objects.filter(hostname__startswith='a')
|
||||
assets2 = Asset.objects.filter(hostname__startswith='b')
|
||||
|
||||
qs = cls(assets1, assets2)
|
||||
return qs
|
||||
|
||||
|
||||
class QuerySetStage:
|
||||
def __init__(self):
|
||||
self._prefetch_related = set()
|
||||
|
@ -541,14 +618,13 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
|
|||
|
||||
class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
|
||||
|
||||
def get_favorite_assets(self, qs_stage: QuerySetStage = None, only=('id', )) -> AssetQuerySet:
|
||||
def get_favorite_assets(self, only=('id', )) -> QuerySet:
|
||||
favorite_asset_ids = FavoriteAsset.objects.filter(
|
||||
user=self.user
|
||||
).values_list('asset_id', flat=True)
|
||||
favorite_asset_ids = list(favorite_asset_ids)
|
||||
qs_stage = qs_stage or QuerySetStage()
|
||||
qs_stage.filter(id__in=favorite_asset_ids).only(*only)
|
||||
assets = self.get_all_granted_assets(qs_stage)
|
||||
assets = self.get_all_granted_assets()
|
||||
assets = assets.filter(id__in=favorite_asset_ids).only(*only)
|
||||
return assets
|
||||
|
||||
def get_ungroup_assets(self) -> AssetQuerySet:
|
||||
|
@ -560,39 +636,30 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
|
|||
).distinct()
|
||||
return queryset
|
||||
|
||||
def get_direct_granted_nodes_assets(self, qs_stage: QuerySetStage = None) -> AssetQuerySet:
|
||||
def get_direct_granted_nodes_assets(self) -> AssetQuerySet:
|
||||
granted_node_ids = AssetPermission.nodes.through.objects.filter(
|
||||
assetpermission_id__in=self.asset_perm_ids
|
||||
).values_list('node_id', flat=True).distinct()
|
||||
granted_node_ids = list(granted_node_ids)
|
||||
granted_nodes = PermNode.objects.filter(id__in=granted_node_ids).only('id', 'key')
|
||||
queryset = PermNode.get_nodes_all_assets(*granted_nodes)
|
||||
if qs_stage:
|
||||
queryset = qs_stage.merge(queryset)
|
||||
return queryset
|
||||
|
||||
def get_all_granted_assets(self, qs_stage: QuerySetStage = None) -> AssetQuerySet:
|
||||
def get_all_granted_assets(self) -> QuerySet:
|
||||
nodes_assets = self.get_direct_granted_nodes_assets()
|
||||
assets = self.get_direct_granted_assets()
|
||||
|
||||
if qs_stage:
|
||||
nodes_assets, assets = qs_stage.merge_multi_before_union(nodes_assets, assets)
|
||||
queryset = nodes_assets.union(assets)
|
||||
if qs_stage:
|
||||
queryset = qs_stage.merge_after_union(queryset)
|
||||
queryset = UnionQuerySet(nodes_assets, assets)
|
||||
return queryset
|
||||
|
||||
def get_node_all_assets(self, id, qs_stage: QuerySetStage = None) -> Tuple[PermNode, QuerySet]:
|
||||
def get_node_all_assets(self, id) -> Tuple[PermNode, QuerySet]:
|
||||
node = PermNode.objects.get(id=id)
|
||||
granted_status = node.get_granted_status(self.user)
|
||||
if granted_status == NodeFrom.granted:
|
||||
assets = PermNode.get_nodes_all_assets(node)
|
||||
if qs_stage:
|
||||
assets = qs_stage.merge(assets)
|
||||
return node, assets
|
||||
elif granted_status in (NodeFrom.asset, NodeFrom.child):
|
||||
node.use_granted_assets_amount()
|
||||
assets = self._get_indirect_granted_node_all_assets(node, qs_stage=qs_stage)
|
||||
assets = self._get_indirect_granted_node_all_assets(node)
|
||||
return node, assets
|
||||
else:
|
||||
node.assets_amount = 0
|
||||
|
@ -614,7 +681,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
|
|||
assets = Asset.objects.order_by().filter(nodes_id=id) & self.get_direct_granted_assets()
|
||||
return assets
|
||||
|
||||
def _get_indirect_granted_node_all_assets(self, node, qs_stage: QuerySetStage = None) -> QuerySet:
|
||||
def _get_indirect_granted_node_all_assets(self, node) -> QuerySet:
|
||||
"""
|
||||
此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询
|
||||
1. 查询该节点下的直接授权节点
|
||||
|
@ -645,10 +712,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
|
|||
nodes__id__in=only_asset_granted_node_ids,
|
||||
granted_by_permissions__id__in=self.asset_perm_ids
|
||||
).distinct().order_by()
|
||||
if qs_stage:
|
||||
node_assets, assets = qs_stage.merge_multi_before_union(node_assets, assets)
|
||||
granted_assets = node_assets.union(assets)
|
||||
granted_assets = qs_stage.merge_after_union(granted_assets)
|
||||
granted_assets = UnionQuerySet(node_assets, assets)
|
||||
return granted_assets
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue