fix: 修复授权树一些问题

pull/5663/head
xinwen 2021-02-25 14:45:21 +08:00 committed by 老广
parent 5de5fa2e96
commit 1036d1c132
13 changed files with 214 additions and 182 deletions

View File

@ -337,9 +337,9 @@ class NodeAllAssetsMappingMixin:
t1 = time.time() t1 = time.time()
with tmp_to_org(org_id): with tmp_to_org(org_id):
nodes_id_key = Node.objects.filter(org_id=org_id) \ nodes_id_key = Node.objects.annotate(
.annotate(char_id=output_as_string('id')) \ char_id=output_as_string('id')
.values_list('char_id', 'key') ).values_list('char_id', 'key')
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢) # * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_assets_id = Asset.nodes.through.objects.all() \ nodes_assets_id = Asset.nodes.through.objects.all() \

View File

@ -10,8 +10,11 @@
""" """
import uuid import uuid
from functools import reduce, partial
import inspect
from django.db.models import * from django.db.models import *
from django.db.models import QuerySet
from django.db.models.functions import Concat from django.db.models.functions import Concat
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -86,3 +89,84 @@ def concated_display(name1, name2):
def output_as_string(field_name): def output_as_string(field_name):
return ExpressionWrapper(F(field_name), output_field=CharField()) return ExpressionWrapper(F(field_name), output_field=CharField())
class UnionQuerySet(QuerySet):
after_union = ['order_by']
not_return_qs = [
'query', 'get', 'create', 'get_or_create',
'update_or_create', 'bulk_create', 'count',
'latest', 'earliest', 'first', 'last', 'aggregate',
'exists', 'update', 'delete', 'as_manager', 'explain',
]
def __init__(self, *queryset_list):
self.queryset_list = queryset_list
self.after_union_items = []
self.before_union_items = []
def __execute(self):
queryset_list = []
for qs in self.queryset_list:
for attr, args, kwargs in self.before_union_items:
qs = getattr(qs, attr)(*args, **kwargs)
queryset_list.append(qs)
union_qs = reduce(lambda x, y: x.union(y), queryset_list)
for attr, args, kwargs in self.after_union_items:
union_qs = getattr(union_qs, attr)(*args, **kwargs)
return union_qs
def __before_union_perform(self, item, *args, **kwargs):
self.before_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __after_union_perform(self, item, *args, **kwargs):
self.after_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __clone(self, *queryset_list):
uqs = UnionQuerySet(*queryset_list)
uqs.after_union_items = self.after_union_items
uqs.before_union_items = self.before_union_items
return uqs
def __getattribute__(self, item):
if item.startswith('__') or item in UnionQuerySet.__dict__ or item in [
'queryset_list', 'after_union_items', 'before_union_items'
]:
return object.__getattribute__(self, item)
if item in UnionQuerySet.not_return_qs:
return getattr(self.__execute(), item)
origin_item = object.__getattribute__(self, 'queryset_list')[0]
origin_attr = getattr(origin_item, item, None)
if not inspect.ismethod(origin_attr):
return getattr(self.__execute(), item)
if item in UnionQuerySet.after_union:
attr = partial(self.__after_union_perform, item)
else:
attr = partial(self.__before_union_perform, item)
return attr
def __getitem__(self, item):
return self.__execute()[item]
def __iter__(self):
return iter(self.__execute())
def __str__(self):
return str(self.__execute())
def __repr__(self):
return repr(self.__execute())
@classmethod
def test_it(cls):
from assets.models import Asset
assets1 = Asset.objects.filter(hostname__startswith='a')
assets2 = Asset.objects.filter(hostname__startswith='b')
qs = cls(assets1, assets2)
return qs

View File

@ -277,7 +277,3 @@ def bulk_get(d, *keys, default=None):
for key in keys: for key in keys:
values.append(d.get(key, default)) values.append(d.get(key, default))
return values return values
def isinstance_method(attr):
return isinstance(attr, type(Time().time))

View File

@ -93,21 +93,21 @@ class DistributedLock(RedisLock):
if self.locked_by_current_thread(): if self.locked_by_current_thread():
self._acquired_reentrant_lock = True self._acquired_reentrant_lock = True
logger.debug( logger.debug(
f'I[{self.id}] reentry lock[{self.name}] in thread[{self._thread_id}].') f'Reentry lock ok: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
return True return True
logger.debug(f'I[{self.id}] attempt acquire reentrant-lock[{self.name}].') logger.debug(f'Attempt acquire reentrant-lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
acquired = super().acquire(blocking=blocking, timeout=timeout) acquired = super().acquire(blocking=blocking, timeout=timeout)
if acquired: if acquired:
logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] now.') logger.debug(f'Acquired reentrant-lock ok: lock_id={self.id} lock={self.name} thread={self._thread_id}')
setattr(thread_local, self.name, self.id) setattr(thread_local, self.name, self.id)
else: else:
logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] failed.') logger.debug(f'Acquired reentrant-lock failed: lock_id={self.id} lock={self.name} thread={self._thread_id}')
return acquired return acquired
else: else:
logger.debug(f'I[{self.id}] attempt acquire lock[{self.name}].') logger.debug(f'Attempt acquire lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
acquired = super().acquire(blocking=blocking, timeout=timeout) acquired = super().acquire(blocking=blocking, timeout=timeout)
logger.debug(f'I[{self.id}] acquired lock[{self.name}] {acquired}.') logger.debug(f'Acquired lock: ok={acquired} lock_id={self.id} lock={self.name} thread={self._thread_id}')
return acquired return acquired
@property @property
@ -126,17 +126,17 @@ class DistributedLock(RedisLock):
def _release_on_reentrant_locked_by_brother(self): def _release_on_reentrant_locked_by_brother(self):
if self._acquired_reentrant_lock: if self._acquired_reentrant_lock:
self._acquired_reentrant_lock = False self._acquired_reentrant_lock = False
logger.debug(f'I[{self.id}] released reentrant-lock[{self.name}] owner[{self.get_owner_id()}] in thread[{self._thread_id}]') logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
return return
else: else:
self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired by me[{self.id}].') self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
def _release_on_reentrant_locked_by_me(self): def _release_on_reentrant_locked_by_me(self):
logger.debug(f'I[{self.id}] release reentrant-lock[{self.name}] in thread[{self._thread_id}]') logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name} thread={self._thread_id}')
id = getattr(thread_local, self.name, None) id = getattr(thread_local, self.name, None)
if id != self.id: if id != self.id:
raise PermissionError(f'Reentrant-lock[{self.name}] is not locked by me[{self.id}], owner[{id}]') raise PermissionError(f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
try: try:
# 这里要保证先删除 thread_local 的标记, # 这里要保证先删除 thread_local 的标记,
delattr(thread_local, self.name) delattr(thread_local, self.name)
@ -158,9 +158,9 @@ class DistributedLock(RedisLock):
def _release(self): def _release(self):
try: try:
self._release_redis_lock() self._release_redis_lock()
logger.debug(f'I[{self.id}] released lock[{self.name}]') logger.debug(f'Released lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
except NotAcquired as e: except NotAcquired as e:
logger.error(f'I[{self.id}] release lock[{self.name}] failed {e}') logger.error(f'Release lock failed: lock_id={self.id} lock={self.name} thread={self._thread_id} error: {e}')
self._raise_exc(e) self._raise_exc(e)
def release(self): def release(self):
@ -174,11 +174,11 @@ class DistributedLock(RedisLock):
else: else:
_release = self._release_on_reentrant_locked_by_brother _release = self._release_on_reentrant_locked_by_brother
else: else:
self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired in current-thread[{self._thread_id}]') self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} lock={self.name} thread={self._thread_id}')
# 处理是否在事务提交时才释放锁 # 处理是否在事务提交时才释放锁
if self._release_on_transaction_commit: if self._release_on_transaction_commit:
logger.debug(f'I[{self.id}] release lock[{self.name}] on transaction commit ...') logger.debug(f'Release lock on transaction commit ... :lock_id={self.id} lock={self.name} thread={self._thread_id}')
transaction.on_commit(_release) transaction.on_commit(_release)
else: else:
_release() _release()

View File

@ -0,0 +1,18 @@
# Generated by Django 3.1 on 2021-02-26 07:36
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('orgs', '0009_auto_20201023_1628'),
]
operations = [
migrations.AlterField(
model_name='organizationmember',
name='role',
field=models.CharField(choices=[('Admin', 'Organization administrator'), ('Auditor', 'Organization auditor'), ('User', 'User')], db_index=True, default='User', max_length=16, verbose_name='Role'),
),
]

View File

@ -424,7 +424,7 @@ class OrganizationMember(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True) id = models.UUIDField(default=uuid.uuid4, primary_key=True)
org = models.ForeignKey(Organization, related_name='m2m_org_members', on_delete=models.CASCADE, verbose_name=_('Organization')) org = models.ForeignKey(Organization, related_name='m2m_org_members', on_delete=models.CASCADE, verbose_name=_('Organization'))
user = models.ForeignKey('users.User', related_name='m2m_org_members', on_delete=models.CASCADE, verbose_name=_('User')) user = models.ForeignKey('users.User', related_name='m2m_org_members', on_delete=models.CASCADE, verbose_name=_('User'))
role = models.CharField(max_length=16, choices=ROLE.choices, default=ROLE.USER, verbose_name=_("Role")) role = models.CharField(db_index=True, max_length=16, choices=ROLE.choices, default=ROLE.USER, verbose_name=_("Role"))
date_created = models.DateTimeField(auto_now_add=True, verbose_name=_("Date created")) date_created = models.DateTimeField(auto_now_add=True, verbose_name=_("Date created"))
date_updated = models.DateTimeField(auto_now=True, verbose_name=_("Date updated")) date_updated = models.DateTimeField(auto_now=True, verbose_name=_("Date updated"))
created_by = models.CharField(max_length=128, null=True, verbose_name=_('Created by')) created_by = models.CharField(max_length=128, null=True, verbose_name=_('Created by'))

View File

@ -13,6 +13,7 @@ from orgs.utils import current_org
from common.permissions import IsOrgAdmin from common.permissions import IsOrgAdmin
from perms import serializers from perms import serializers
from perms import models from perms import models
from perms.utils.asset.user_permission import UserGrantedAssetsQueryUtils
__all__ = [ __all__ = [
'AssetPermissionUserRelationViewSet', 'AssetPermissionUserGroupRelationViewSet', 'AssetPermissionUserRelationViewSet', 'AssetPermissionUserGroupRelationViewSet',
@ -103,15 +104,8 @@ class AssetPermissionAllAssetListApi(generics.ListAPIView):
def get_queryset(self): def get_queryset(self):
pk = self.kwargs.get("pk") pk = self.kwargs.get("pk")
perm = get_object_or_404(models.AssetPermission, pk=pk) query_utils = UserGrantedAssetsQueryUtils(None, asset_perm_ids=[pk])
assets = query_utils.get_all_granted_assets()
asset_q = Q(granted_by_permissions=perm)
granted_node_keys = Node.objects.filter(granted_by_permissions=perm).distinct().values_list('key', flat=True)
for key in granted_node_keys:
asset_q |= Q(nodes__key__startswith=f'{key}:')
asset_q |= Q(nodes__key=key)
assets = Asset.objects.filter(asset_q).only(*self.serializer_class.Meta.only_fields).distinct()
return assets return assets

View File

@ -13,7 +13,7 @@ from perms.utils.asset.user_permission import UserGrantedTreeRefreshController
class PermBaseMixin: class PermBaseMixin:
user: User user: User
def get(self, request, *args, **kwargs): def get(self, request: Request, *args, **kwargs):
force = is_true(request.query_params.get('rebuild_tree')) force = is_true(request.query_params.get('rebuild_tree'))
controller = UserGrantedTreeRefreshController(self.user) controller = UserGrantedTreeRefreshController(self.user)
controller.refresh_if_need(force) controller.refresh_if_need(force)

View File

@ -1,6 +1,7 @@
from django_filters import rest_framework as filters from django_filters import rest_framework as filters
from django.db.models import QuerySet, Q from django.db.models import QuerySet, Q
from common.db.models import UnionQuerySet
from common.drf.filters import BaseFilterSet from common.drf.filters import BaseFilterSet
from common.utils import get_object_or_none from common.utils import get_object_or_none
from users.models import User, UserGroup from users.models import User, UserGroup
@ -134,13 +135,15 @@ class AssetPermissionFilter(PermissionBaseFilter):
if not _nodes: if not _nodes:
return queryset.none() return queryset.none()
node = _nodes.get()
if not is_query_all: if not is_query_all:
queryset = queryset.filter(nodes__in=_nodes) queryset = queryset.filter(nodes=node)
return queryset return queryset
nodes = set(_nodes) nodeids = node.get_ancestors(with_self=True).values_list('id', flat=True)
for node in _nodes: nodeids = list(nodeids)
nodes |= set(node.get_ancestors(with_self=True))
queryset = queryset.filter(nodes__in=nodes) queryset = queryset.filter(nodes__in=nodeids)
return queryset return queryset
def filter_asset(self, queryset): def filter_asset(self, queryset):
@ -159,21 +162,26 @@ class AssetPermissionFilter(PermissionBaseFilter):
return queryset return queryset
if not assets: if not assets:
return queryset.none() return queryset.none()
if not is_query_all: asset = assets.get()
queryset = queryset.filter(assets__in=assets)
return queryset
inherit_all_nodes = set()
inherit_nodes_keys = assets.all().values_list('nodes__key', flat=True)
for key in inherit_nodes_keys: if not is_query_all:
if key is None: queryset = queryset.filter(assets=asset)
continue return queryset
inherit_all_nodekeys = set()
inherit_nodekeys = asset.nodes.values_list('key', flat=True)
for key in inherit_nodekeys:
ancestor_keys = Node.get_node_ancestor_keys(key, with_self=True) ancestor_keys = Node.get_node_ancestor_keys(key, with_self=True)
inherit_all_nodes.update(ancestor_keys) inherit_all_nodekeys.update(ancestor_keys)
queryset = queryset.filter(
Q(assets__in=assets) | Q(nodes__key__in=inherit_all_nodes) inherit_all_nodeids = Node.objects.filter(key__in=inherit_all_nodekeys).values_list('id', flat=True)
).distinct() inherit_all_nodeids = list(inherit_all_nodeids)
return queryset
qs1 = queryset.filter(assets=asset).distinct()
qs2 = queryset.filter(nodes__id__in=inherit_all_nodeids).distinct()
qs = UnionQuerySet(qs1, qs2)
return qs
def filter_effective(self, queryset): def filter_effective(self, queryset):
is_effective = self.get_query_param('is_effective') is_effective = self.get_query_param('is_effective')

View File

@ -2,10 +2,10 @@ from common.utils.lock import DistributedLock
class UserGrantedTreeRebuildLock(DistributedLock): class UserGrantedTreeRebuildLock(DistributedLock):
name_template = 'perms.user.asset.node.tree.rebuid.<org_id:{org_id}>.<user_id:{user_id}>' name_template = 'perms.user.asset.node.tree.rebuid.<user_id:{user_id}>'
def __init__(self, org_id, user_id): def __init__(self, user_id):
name = self.name_template.format( name = self.name_template.format(
org_id=org_id, user_id=user_id user_id=user_id
) )
super().__init__(name=name) super().__init__(name=name, release_on_transaction_commit=True)

View File

@ -193,25 +193,11 @@ class PermNode(Node):
node_from = '' node_from = ''
granted_assets_amount = 0 granted_assets_amount = 0
# 提供可以设置 资产数量的字段
_assets_amount = None
annotate_granted_node_rel_fields = { annotate_granted_node_rel_fields = {
'granted_assets_amount': F('granted_node_rels__node_assets_amount'), 'granted_assets_amount': F('granted_node_rels__node_assets_amount'),
'node_from': F('granted_node_rels__node_from') 'node_from': F('granted_node_rels__node_from')
} }
@property
def assets_amount(self):
_assets_amount = getattr(self, '_assets_amount')
if isinstance(_assets_amount, int):
return _assets_amount
return super().assets_amount
@assets_amount.setter
def assets_amount(self, value):
self._assets_amount = value
def use_granted_assets_amount(self): def use_granted_assets_amount(self):
self.assets_amount = self.granted_assets_amount self.assets_amount = self.granted_assets_amount

View File

@ -8,6 +8,7 @@ from django.db.models import Q
from django.utils import timezone from django.utils import timezone
from orgs.mixins.models import OrgModelMixin from orgs.mixins.models import OrgModelMixin
from common.db.models import UnionQuerySet
from common.utils import date_expired_default, lazyproperty from common.utils import date_expired_default, lazyproperty
from orgs.mixins.models import OrgManager from orgs.mixins.models import OrgManager
@ -100,10 +101,15 @@ class BasePermission(OrgModelMixin):
from users.models import User from users.models import User
users_id = self.users.all().values_list('id', flat=True) users_id = self.users.all().values_list('id', flat=True)
groups_id = self.user_groups.all().values_list('id', flat=True) groups_id = self.user_groups.all().values_list('id', flat=True)
users = User.objects.filter(
Q(id__in=users_id) | Q(groups__id__in=groups_id) users_id = list(users_id)
).distinct() groups_id = list(groups_id)
return users
qs1 = User.objects.filter(id__in=users_id).distinct()
qs2 = User.objects.filter(groups__id__in=groups_id).distinct()
qs = UnionQuerySet(qs1, qs2)
return qs
@lazyproperty @lazyproperty
def users_amount(self): def users_amount(self):

View File

@ -1,18 +1,16 @@
from collections import defaultdict from collections import defaultdict
from typing import List, Tuple from typing import List, Tuple
from functools import reduce, partial
from common.utils import isinstance_method
from django.core.cache import cache from django.core.cache import cache
from django.conf import settings from django.conf import settings
from django.db.models import Q, QuerySet from django.db.models import Q, QuerySet
from common.db.models import output_as_string from common.db.models import output_as_string, UnionQuerySet
from common.utils.common import lazyproperty, timeit, Time from common.utils.common import lazyproperty, timeit
from assets.utils import NodeAssetsUtil from assets.utils import NodeAssetsUtil
from common.utils import get_logger from common.utils import get_logger
from common.decorator import on_transaction_commit from common.decorator import on_transaction_commit
from orgs.utils import tmp_to_org, current_org, ensure_in_real_or_default_org from orgs.utils import tmp_to_org, current_org, ensure_in_real_or_default_org, tmp_to_root_org
from assets.models import ( from assets.models import (
Asset, FavoriteAsset, AssetQuerySet, NodeQuerySet Asset, FavoriteAsset, AssetQuerySet, NodeQuerySet
) )
@ -50,84 +48,10 @@ def get_user_all_asset_perm_ids(user) -> set:
asset_perm_ids = AssetPermission.objects.filter( asset_perm_ids = AssetPermission.objects.filter(
id__in=asset_perm_ids).valid().values_list('id', flat=True) id__in=asset_perm_ids).valid().values_list('id', flat=True)
asset_perm_ids = set(asset_perm_ids)
return asset_perm_ids return asset_perm_ids
class UnionQuerySet(QuerySet):
after_union = ['order_by']
not_return_qs = [
'query', 'get', 'create', 'get_or_create',
'update_or_create', 'bulk_create', 'count',
'latest', 'earliest', 'first', 'last', 'aggregate',
'exists', 'update', 'delete', 'as_manager', 'explain',
]
def __init__(self, *queryset_list):
self.queryset_list = queryset_list
self.after_union_items = []
self.before_union_items = []
def __execute(self):
queryset_list = []
for qs in self.queryset_list:
for attr, args, kwargs in self.before_union_items:
qs = getattr(qs, attr)(*args, **kwargs)
queryset_list.append(qs)
union_qs = reduce(lambda x, y: x.union(y), queryset_list)
for attr, args, kwargs in self.after_union_items:
union_qs = getattr(union_qs, attr)(*args, **kwargs)
return union_qs
def __before_union_perform(self, item, *args, **kwargs):
self.before_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __after_union_perform(self, item, *args, **kwargs):
self.after_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __clone(self, *queryset_list):
uqs = UnionQuerySet(*queryset_list)
uqs.after_union_items = self.after_union_items
uqs.before_union_items = self.before_union_items
return uqs
def __getattribute__(self, item):
if item.startswith('__') or item in UnionQuerySet.__dict__ or item in [
'queryset_list', 'after_union_items', 'before_union_items'
]:
return object.__getattribute__(self, item)
if item in UnionQuerySet.not_return_qs:
return getattr(self.__execute(), item)
origin_item = object.__getattribute__(self, 'queryset_list')[0]
origin_attr = getattr(origin_item, item, None)
if not isinstance_method(origin_attr):
return getattr(self.__execute(), item)
if item in UnionQuerySet.after_union:
attr = partial(self.__after_union_perform, item)
else:
attr = partial(self.__before_union_perform, item)
return attr
def __getitem__(self, item):
return self.__execute()[item]
def __iter__(self):
return iter(self.__execute())
@classmethod
def test_it(cls):
from assets.models import Asset
assets1 = Asset.objects.filter(hostname__startswith='a')
assets2 = Asset.objects.filter(hostname__startswith='b')
qs = cls(assets1, assets2)
return qs
class QuerySetStage: class QuerySetStage:
def __init__(self): def __init__(self):
self._prefetch_related = set() self._prefetch_related = set()
@ -273,11 +197,11 @@ class UserGrantedTreeRefreshController:
ret = p.execute() ret = p.execute()
builded_orgs_id = {org_id.decode() for org_id in ret[0]} builded_orgs_id = {org_id.decode() for org_id in ret[0]}
ids = orgs_id - builded_orgs_id ids = orgs_id - builded_orgs_id
orgs = [] orgs = set()
if Organization.DEFAULT_ID in ids: if Organization.DEFAULT_ID in ids:
ids.remove(Organization.DEFAULT_ID) ids.remove(Organization.DEFAULT_ID)
orgs.append(Organization.default()) orgs.add(Organization.default())
orgs.extend(Organization.objects.filter(id__in=ids)) orgs.update(Organization.objects.filter(id__in=ids))
logger.info(f'Need rebuild orgs are {orgs}, builed orgs are {ret[0]}, all orgs are {orgs_id}') logger.info(f'Need rebuild orgs are {orgs}, builed orgs are {ret[0]}, all orgs are {orgs_id}')
return orgs return orgs
@ -292,7 +216,7 @@ class UserGrantedTreeRefreshController:
key = cls.key_template.format(user_id=user_id) key = cls.key_template.format(user_id=user_id)
p.srem(key, *org_ids) p.srem(key, *org_ids)
p.execute() p.execute()
logger.info(f'Remove orgs from users builded tree, users:{users_id} orgs:{orgs_id}') logger.info(f'Remove orgs from users builded tree: users:{users_id} orgs:{orgs_id}')
@classmethod @classmethod
def add_need_refresh_orgs_for_users(cls, orgs_id, users_id): def add_need_refresh_orgs_for_users(cls, orgs_id, users_id):
@ -364,29 +288,37 @@ class UserGrantedTreeRefreshController:
@lazyproperty @lazyproperty
def orgs_id(self): def orgs_id(self):
ret = [org.id for org in self.orgs] ret = {org.id for org in self.orgs}
return ret return ret
@lazyproperty @lazyproperty
def orgs(self): def orgs(self):
orgs = [*self.user.orgs.all(), Organization.default()] orgs = {*self.user.orgs.all().distinct(), Organization.default()}
return orgs return orgs
@timeit @timeit
def refresh_if_need(self, force=False): def refresh_if_need(self, force=False):
user = self.user user = self.user
exists = UserAssetGrantedTreeNodeRelation.objects.filter(user=user).exists()
if force or not exists: with UserGrantedTreeRebuildLock(user_id=user.id):
orgs = self.orgs with tmp_to_root_org():
self.set_all_orgs_as_builed() orgids = self.orgs_id.copy()
else: orgids.remove(Organization.default().id)
orgs = self.get_need_refresh_orgs_and_fill_up() orgids.add('') # 添加 default
for org in orgs: UserAssetGrantedTreeNodeRelation.objects.filter(user=user).exclude(org_id__in=orgids).delete()
with tmp_to_org(org): exists = UserAssetGrantedTreeNodeRelation.objects.filter(user=user).exists()
utils = UserGrantedTreeBuildUtils(user)
utils.rebuild_user_granted_tree() if force or not exists:
orgs = self.orgs
self.set_all_orgs_as_builed()
else:
orgs = self.get_need_refresh_orgs_and_fill_up()
for org in orgs:
with tmp_to_org(org):
utils = UserGrantedTreeBuildUtils(user)
utils.rebuild_user_granted_tree()
class UserGrantedUtilsBase: class UserGrantedUtilsBase:
@ -394,7 +326,7 @@ class UserGrantedUtilsBase:
def __init__(self, user, asset_perm_ids=None): def __init__(self, user, asset_perm_ids=None):
self.user = user self.user = user
self._asset_perm_ids = asset_perm_ids self._asset_perm_ids = asset_perm_ids and set(asset_perm_ids)
@lazyproperty @lazyproperty
def asset_perm_ids(self) -> set: def asset_perm_ids(self) -> set:
@ -431,24 +363,26 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
@timeit @timeit
@ensure_in_real_or_default_org @ensure_in_real_or_default_org
def rebuild_user_granted_tree(self): def rebuild_user_granted_tree(self):
logger.info(f'Rebuild user:{self.user} tree in org:{current_org}') """
注意调用该方法一定要被 `UserGrantedTreeRebuildLock` 锁住
"""
logger.info(f'Rebuild user tree: user={self.user} org={current_org}')
user = self.user user = self.user
org_id = current_org.id
with UserGrantedTreeRebuildLock(org_id, user.id): # 先删除旧的授权树🌲
# 先删除旧的授权树🌲 UserAssetGrantedTreeNodeRelation.objects.filter(user=user).delete()
UserAssetGrantedTreeNodeRelation.objects.filter(user=user).delete()
if not self.asset_perm_ids: if not self.asset_perm_ids:
# 没有授权直接返回 # 没有授权直接返回
return return
nodes = self.compute_perm_nodes_tree() nodes = self.compute_perm_nodes_tree()
self.compute_node_assets_amount(nodes) self.compute_node_assets_amount(nodes)
if not nodes: if not nodes:
return return
self.create_mapping_nodes(nodes) self.create_mapping_nodes(nodes)
@timeit @timeit
def compute_perm_nodes_tree(self, node_only_fields=NODE_ONLY_FIELDS) -> list: def compute_perm_nodes_tree(self, node_only_fields=NODE_ONLY_FIELDS) -> list:
@ -587,6 +521,8 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
node_asset_pairs = list(node_asset_pairs) node_asset_pairs = list(node_asset_pairs)
for node_id, asset_id in node_asset_pairs: for node_id, asset_id in node_asset_pairs:
if node_id not in node_id_key_mapper:
continue
nkey = node_id_key_mapper[node_id] nkey = node_id_key_mapper[node_id]
nodekey_assetsid_mapper[nkey].add(asset_id) nodekey_assetsid_mapper[nkey].add(asset_id)
@ -695,6 +631,10 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
).filter( ).filter(
Q(node_key__startswith=f'{node.key}:') Q(node_key__startswith=f'{node.key}:')
).only('node_id', 'node_key') ).only('node_id', 'node_key')
for n in granted_nodes:
n.id = n.node_id
node_assets = PermNode.get_nodes_all_assets(*granted_nodes) node_assets = PermNode.get_nodes_all_assets(*granted_nodes)
# 查询该节点下的资产授权节点 # 查询该节点下的资产授权节点