添加 UnionQuertSet (#5578)

* 添加 UnionQuertSet

* 跑通了

* 改变了 count 这类方法的代理模式

* 使用了老广的

Co-authored-by: xinwen <coderWen@126.com>
pull/5605/head
fit2bot 2021-02-07 10:15:39 +08:00 committed by GitHub
parent 50e6c96358
commit 501ad698b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 104 additions and 59 deletions

View File

@ -273,3 +273,7 @@ class Time:
for timestamp, msg in zip(timestamps, self._msgs):
logger.debug(f'TIME_IT: {msg} {timestamp-last}')
last = timestamp
def isinstance_method(attr):
return isinstance(attr, type(Time().time))

View File

@ -34,12 +34,12 @@ class UserAllGrantedAssetsQuerysetMixin:
pagination_class = AllGrantedAssetPagination
user: User
def get_union_queryset(self, qs_stage: QuerySetStage):
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
qs_stage.prefetch_related('platform').only(*self.only_fields)
queryset = UserGrantedAssetsQueryUtils(self.user) \
.get_all_granted_assets(qs_stage)
.get_all_granted_assets()
queryset = queryset.prefetch_related('platform').only(*self.only_fields)
return queryset
@ -47,13 +47,13 @@ class UserFavoriteGrantedAssetsMixin:
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
user: User
def get_union_queryset(self, qs_stage: QuerySetStage):
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
user = self.user
qs_stage.prefetch_related('platform').only(*self.only_fields)
utils = UserGrantedAssetsQueryUtils(user)
assets = utils.get_favorite_assets(qs_stage=qs_stage)
assets = utils.get_favorite_assets()
assets = assets.prefetch_related('platform').only(*self.only_fields)
return assets
@ -63,58 +63,35 @@ class UserGrantedNodeAssetsMixin:
pagination_node: Node
user: User
def get_union_queryset(self, qs_stage: QuerySetStage):
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
node_id = self.kwargs.get("node_id")
qs_stage.prefetch_related('platform').only(*self.only_fields)
node, assets = UserGrantedAssetsQueryUtils(self.user).get_node_all_assets(
node_id, qs_stage=qs_stage
node_id
)
assets = assets.prefetch_related('platform').only(*self.only_fields)
self.pagination_node = node
return assets
# 控制格式的 ----------------------------------------------------
class AssetsUnionQuerysetMixin:
def get_queryset_union_prefer(self):
if hasattr(self, 'get_union_queryset'):
# 为了支持 union 查询
queryset = Asset.objects.all().distinct()
queryset = self.filter_queryset(queryset)
qs_stage = QuerySetStage()
qs_stage.and_with_queryset(queryset)
queryset = self.get_union_queryset(qs_stage)
else:
queryset = self.filter_queryset(self.get_queryset())
return queryset
class AssetsSerializerFormatMixin(AssetsUnionQuerysetMixin):
class AssetsSerializerFormatMixin:
serializer_class = serializers.AssetGrantedSerializer
filterset_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def list(self, request, *args, **kwargs):
queryset = self.get_queryset_union_prefer()
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
class AssetsTreeFormatMixin(AssetsUnionQuerysetMixin, SerializeToTreeNodeMixin):
class AssetsTreeFormatMixin(SerializeToTreeNodeMixin):
"""
资产 序列化成树的结构返回
"""
def list(self, request: Request, *args, **kwargs):
queryset = self.get_queryset_union_prefer()
queryset = self.filter_queryset(self.get_queryset())
if request.query_params.get('search'):
# 如果用户搜索的条件不精准,会导致返回大量的无意义数据。

View File

@ -1,5 +1,7 @@
from collections import defaultdict
from typing import List, Tuple
from functools import reduce, partial
from common.utils import isinstance_method
from django.core.cache import cache
from django.conf import settings
@ -51,6 +53,81 @@ def get_user_all_asset_perm_ids(user) -> set:
return asset_perm_ids
class UnionQuerySet(QuerySet):
after_union = ['order_by']
not_return_qs = [
'query', 'get', 'create', 'get_or_create',
'update_or_create', 'bulk_create', 'count',
'latest', 'earliest', 'first', 'last', 'aggregate',
'exists', 'update', 'delete', 'as_manager', 'explain',
]
def __init__(self, *queryset_list):
self.queryset_list = queryset_list
self.after_union_items = []
self.before_union_items = []
def __execute(self):
queryset_list = []
for qs in self.queryset_list:
for attr, args, kwargs in self.before_union_items:
qs = getattr(qs, attr)(*args, **kwargs)
queryset_list.append(qs)
union_qs = reduce(lambda x, y: x.union(y), queryset_list)
for attr, args, kwargs in self.after_union_items:
union_qs = getattr(union_qs, attr)(*args, **kwargs)
return union_qs
def __before_union_perform(self, item, *args, **kwargs):
self.before_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __after_union_perform(self, item, *args, **kwargs):
self.after_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __clone(self, *queryset_list):
uqs = UnionQuerySet(*queryset_list)
uqs.after_union_items = self.after_union_items
uqs.before_union_items = self.before_union_items
return uqs
def __getattribute__(self, item):
if item.startswith('__') or item in UnionQuerySet.__dict__ or item in [
'queryset_list', 'after_union_items', 'before_union_items'
]:
return object.__getattribute__(self, item)
if item in UnionQuerySet.not_return_qs:
return getattr(self.__execute(), item)
origin_item = object.__getattribute__(self, 'queryset_list')[0]
origin_attr = getattr(origin_item, item, None)
if not isinstance_method(origin_attr):
return getattr(self.__execute(), item)
if item in UnionQuerySet.after_union:
attr = partial(self.__after_union_perform, item)
else:
attr = partial(self.__before_union_perform, item)
return attr
def __getitem__(self, item):
return self.__execute()[item]
def __next__(self):
return next(self.__execute())
@classmethod
def test_it(cls):
from assets.models import Asset
assets1 = Asset.objects.filter(hostname__startswith='a')
assets2 = Asset.objects.filter(hostname__startswith='b')
qs = cls(assets1, assets2)
return qs
class QuerySetStage:
def __init__(self):
self._prefetch_related = set()
@ -541,14 +618,13 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
def get_favorite_assets(self, qs_stage: QuerySetStage = None, only=('id', )) -> AssetQuerySet:
def get_favorite_assets(self, only=('id', )) -> QuerySet:
favorite_asset_ids = FavoriteAsset.objects.filter(
user=self.user
).values_list('asset_id', flat=True)
favorite_asset_ids = list(favorite_asset_ids)
qs_stage = qs_stage or QuerySetStage()
qs_stage.filter(id__in=favorite_asset_ids).only(*only)
assets = self.get_all_granted_assets(qs_stage)
assets = self.get_all_granted_assets()
assets = assets.filter(id__in=favorite_asset_ids).only(*only)
return assets
def get_ungroup_assets(self) -> AssetQuerySet:
@ -560,39 +636,30 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
).distinct()
return queryset
def get_direct_granted_nodes_assets(self, qs_stage: QuerySetStage = None) -> AssetQuerySet:
def get_direct_granted_nodes_assets(self) -> AssetQuerySet:
granted_node_ids = AssetPermission.nodes.through.objects.filter(
assetpermission_id__in=self.asset_perm_ids
).values_list('node_id', flat=True).distinct()
granted_node_ids = list(granted_node_ids)
granted_nodes = PermNode.objects.filter(id__in=granted_node_ids).only('id', 'key')
queryset = PermNode.get_nodes_all_assets(*granted_nodes)
if qs_stage:
queryset = qs_stage.merge(queryset)
return queryset
def get_all_granted_assets(self, qs_stage: QuerySetStage = None) -> AssetQuerySet:
def get_all_granted_assets(self) -> QuerySet:
nodes_assets = self.get_direct_granted_nodes_assets()
assets = self.get_direct_granted_assets()
if qs_stage:
nodes_assets, assets = qs_stage.merge_multi_before_union(nodes_assets, assets)
queryset = nodes_assets.union(assets)
if qs_stage:
queryset = qs_stage.merge_after_union(queryset)
queryset = UnionQuerySet(nodes_assets, assets)
return queryset
def get_node_all_assets(self, id, qs_stage: QuerySetStage = None) -> Tuple[PermNode, QuerySet]:
def get_node_all_assets(self, id) -> Tuple[PermNode, QuerySet]:
node = PermNode.objects.get(id=id)
granted_status = node.get_granted_status(self.user)
if granted_status == NodeFrom.granted:
assets = PermNode.get_nodes_all_assets(node)
if qs_stage:
assets = qs_stage.merge(assets)
return node, assets
elif granted_status in (NodeFrom.asset, NodeFrom.child):
node.use_granted_assets_amount()
assets = self._get_indirect_granted_node_all_assets(node, qs_stage=qs_stage)
assets = self._get_indirect_granted_node_all_assets(node)
return node, assets
else:
node.assets_amount = 0
@ -614,7 +681,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
assets = Asset.objects.order_by().filter(nodes_id=id) & self.get_direct_granted_assets()
return assets
def _get_indirect_granted_node_all_assets(self, node, qs_stage: QuerySetStage = None) -> QuerySet:
def _get_indirect_granted_node_all_assets(self, node) -> QuerySet:
"""
此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询
1. 查询该节点下的直接授权节点
@ -645,10 +712,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
nodes__id__in=only_asset_granted_node_ids,
granted_by_permissions__id__in=self.asset_perm_ids
).distinct().order_by()
if qs_stage:
node_assets, assets = qs_stage.merge_multi_before_union(node_assets, assets)
granted_assets = node_assets.union(assets)
granted_assets = qs_stage.merge_after_union(granted_assets)
granted_assets = UnionQuerySet(node_assets, assets)
return granted_assets