diff --git a/apps/perms/utils/permission.py b/apps/perms/utils/permission.py index 89407be14..849bec02b 100644 --- a/apps/perms/utils/permission.py +++ b/apps/perms/utils/permission.py @@ -1,3 +1,4 @@ +import django from django.db.models import QuerySet, Model from collections.abc import Iterable from assets.models import Node, Asset @@ -36,18 +37,18 @@ class AssetPermissionUtil(object): group_ids = [g.id for g in user_groups] else: group_ids = user_groups.values_list('id', flat=True).distinct() - group_perm_ids = AssetPermission.user_groups.through.objects \ + perm_ids = AssetPermission.user_groups.through.objects \ .filter(usergroup_id__in=group_ids) \ .values_list('assetpermission_id', flat=True).distinct() if flat: - return group_perm_ids - perms = self.get_permissions(ids=group_perm_ids) + return perm_ids + perms = self.get_permissions(ids=perm_ids) return perms def get_permissions_for_assets(self, assets, with_node=True, flat=False): """ 获取资产的授权规则""" perm_ids = set() - assets = self.transform_to_queryset(assets, Asset) + assets = self.convert_to_queryset_if_need(assets, Asset) asset_ids = [str(a.id) for a in assets] relations = AssetPermission.assets.through.objects.filter(asset_id__in=asset_ids) asset_perm_ids = relations.values_list('assetpermission_id', flat=True).distinct() @@ -63,7 +64,7 @@ class AssetPermissionUtil(object): def get_permissions_for_nodes(self, nodes, with_ancestor=False, flat=False): """ 获取节点的授权规则 """ - nodes = self.transform_to_queryset(nodes, Node) + nodes = self.convert_to_queryset_if_need(nodes, Node) if with_ancestor: nodes = Node.get_ancestor_queryset(nodes) node_ids = nodes.values_list('id', flat=True).distinct() @@ -95,12 +96,15 @@ class AssetPermissionUtil(object): return perms @staticmethod - def transform_to_queryset(objs_or_ids, model): + def convert_to_queryset_if_need(objs_or_ids, model): if not objs_or_ids: return objs_or_ids - if isinstance(objs_or_ids, QuerySet): + if isinstance(objs_or_ids, QuerySet) and isinstance(objs_or_ids.first(), model): return objs_or_ids - ids = [str(o.id) if isinstance(o, model) else o for o in objs_or_ids] + ids = [ + str(i.id) if isinstance(i, model) else i + for i in objs_or_ids + ] return model.objects.filter(id__in=ids)