perf(project): 优化命名的风格 (#5693)

perf: 修改错误的地

perf: 优化写错的几个

Co-authored-by: ibuler <ibuler@qq.com>
pull/5698/head^2
fit2bot 2021-03-08 10:08:51 +08:00 committed by GitHub
parent 935947c97a
commit 0aa2c2016f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 272 additions and 273 deletions

View File

@ -77,8 +77,8 @@ class SerializeApplicationToTreeNodeMixin:
@staticmethod @staticmethod
def filter_organizations(applications): def filter_organizations(applications):
organizations_id = set(applications.values_list('org_id', flat=True)) organization_ids = set(applications.values_list('org_id', flat=True))
organizations = [Organization.get_instance(org_id) for org_id in organizations_id] organizations = [Organization.get_instance(org_id) for org_id in organization_ids]
return organizations return organizations
def serialize_applications_with_org(self, applications): def serialize_applications_with_org(self, applications):

View File

@ -223,8 +223,8 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
instance = self.get_object() instance = self.get_object()
nodes_id = request.data.get("nodes") node_ids = request.data.get("nodes")
children = Node.objects.filter(id__in=nodes_id) children = Node.objects.filter(id__in=node_ids)
for node in children: for node in children:
node.parent = instance node.parent = instance
return Response("OK") return Response("OK")

View File

@ -87,13 +87,13 @@ class SystemUserTaskApi(generics.CreateAPIView):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.SystemUserTaskSerializer serializer_class = serializers.SystemUserTaskSerializer
def do_push(self, system_user, assets_id=None): def do_push(self, system_user, asset_ids=None):
if assets_id is None: if asset_ids is None:
task = push_system_user_to_assets_manual.delay(system_user) task = push_system_user_to_assets_manual.delay(system_user)
else: else:
username = self.request.query_params.get('username') username = self.request.query_params.get('username')
task = push_system_user_to_assets.delay( task = push_system_user_to_assets.delay(
system_user.id, assets_id, username=username system_user.id, asset_ids, username=username
) )
return task return task
@ -114,9 +114,9 @@ class SystemUserTaskApi(generics.CreateAPIView):
system_user = self.get_object() system_user = self.get_object()
if action == 'push': if action == 'push':
assets = [asset] if asset else assets assets = [asset] if asset else assets
assets_id = [asset.id for asset in assets] asset_ids = [asset.id for asset in assets]
assets_id = assets_id if assets_id else None asset_ids = asset_ids if asset_ids else None
task = self.do_push(system_user, assets_id) task = self.do_push(system_user, asset_ids)
else: else:
task = self.do_test(system_user) task = self.do_test(system_user)
data = getattr(serializer, '_data', {}) data = getattr(serializer, '_data', {})

View File

@ -40,7 +40,7 @@ class BaseBackend:
return values return values
@staticmethod @staticmethod
def make_assets_as_id(assets): def make_assets_as_ids(assets):
if not assets: if not assets:
return [] return []
if isinstance(assets[0], Asset): if isinstance(assets[0], Asset):

View File

@ -69,9 +69,9 @@ class DBBackend(BaseBackend):
self.queryset = self.queryset.filter(union_id=union_id) self.queryset = self.queryset.filter(union_id=union_id)
def _filter_assets(self, assets): def _filter_assets(self, assets):
assets_id = self.make_assets_as_id(assets) asset_ids = self.make_assets_as_ids(assets)
if assets_id: if asset_ids:
self.queryset = self.queryset.filter(asset_id__in=assets_id) self.queryset = self.queryset.filter(asset_id__in=asset_ids)
def _filter_node(self, node): def _filter_node(self, node):
pass pass

View File

@ -16,5 +16,5 @@ class FavoriteAsset(CommonModelMixin):
unique_together = ('user', 'asset') unique_together = ('user', 'asset')
@classmethod @classmethod
def get_user_favorite_assets_id(cls, user): def get_user_favorite_asset_ids(cls, user):
return cls.objects.filter(user=user).values_list('asset', flat=True) return cls.objects.filter(user=user).values_list('asset', flat=True)

View File

@ -263,38 +263,38 @@ class NodeAllAssetsMappingMixin:
orgid_nodekey_assetsid_mapping = defaultdict(dict) orgid_nodekey_assetsid_mapping = defaultdict(dict)
@classmethod @classmethod
def get_node_all_assets_id_mapping(cls, org_id): def get_node_all_asset_ids_mapping(cls, org_id):
_mapping = cls.get_node_all_assets_id_mapping_from_memory(org_id) _mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id)
if _mapping: if _mapping:
return _mapping return _mapping
_mapping = cls.get_node_all_assets_id_mapping_from_cache_or_generate_to_cache(org_id) _mapping = cls.get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(org_id)
cls.set_node_all_assets_id_mapping_to_memory(org_id, mapping=_mapping) cls.set_node_all_asset_ids_mapping_to_memory(org_id, mapping=_mapping)
return _mapping return _mapping
# from memory # from memory
@classmethod @classmethod
def get_node_all_assets_id_mapping_from_memory(cls, org_id): def get_node_all_asset_ids_mapping_from_memory(cls, org_id):
mapping = cls.orgid_nodekey_assetsid_mapping.get(org_id, {}) mapping = cls.orgid_nodekey_assetsid_mapping.get(org_id, {})
return mapping return mapping
@classmethod @classmethod
def set_node_all_assets_id_mapping_to_memory(cls, org_id, mapping): def set_node_all_asset_ids_mapping_to_memory(cls, org_id, mapping):
cls.orgid_nodekey_assetsid_mapping[org_id] = mapping cls.orgid_nodekey_assetsid_mapping[org_id] = mapping
@classmethod @classmethod
def expire_node_all_assets_id_mapping_from_memory(cls, org_id): def expire_node_all_asset_ids_mapping_from_memory(cls, org_id):
org_id = str(org_id) org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None) cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
# get order: from memory -> (from cache -> to generate) # get order: from memory -> (from cache -> to generate)
@classmethod @classmethod
def get_node_all_assets_id_mapping_from_cache_or_generate_to_cache(cls, org_id): def get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(cls, org_id):
mapping = cls.get_node_all_assets_id_mapping_from_cache(org_id) mapping = cls.get_node_all_asset_ids_mapping_from_cache(org_id)
if mapping: if mapping:
return mapping return mapping
lock_key = f'KEY_LOCK_GENERATE_ORG_{org_id}_NODE_ALL_ASSETS_ID_MAPPING' lock_key = f'KEY_LOCK_GENERATE_ORG_{org_id}_NODE_ALL_ASSET_ids_MAPPING'
logger.info(f'Thread[{threading.get_ident()}] acquiring lock[{lock_key}] ...') logger.info(f'Thread[{threading.get_ident()}] acquiring lock[{lock_key}] ...')
with DistributedLock(lock_key): with DistributedLock(lock_key):
logger.info(f'Thread[{threading.get_ident()}] acquire lock[{lock_key}] ok') logger.info(f'Thread[{threading.get_ident()}] acquire lock[{lock_key}] ok')
@ -303,67 +303,67 @@ class NodeAllAssetsMappingMixin:
# 这里最好先判断内存中有没有,防止同一进程的多个线程重复从 cache 中获取数据, # 这里最好先判断内存中有没有,防止同一进程的多个线程重复从 cache 中获取数据,
# 但逻辑过于繁琐,直接判断 cache 吧 # 但逻辑过于繁琐,直接判断 cache 吧
_mapping = cls.get_node_all_assets_id_mapping_from_cache(org_id) _mapping = cls.get_node_all_asset_ids_mapping_from_cache(org_id)
if _mapping: if _mapping:
return _mapping return _mapping
_mapping = cls.generate_node_all_assets_id_mapping(org_id) _mapping = cls.generate_node_all_asset_ids_mapping(org_id)
cls.set_node_all_assets_id_mapping_to_cache(org_id=org_id, mapping=_mapping) cls.set_node_all_asset_ids_mapping_to_cache(org_id=org_id, mapping=_mapping)
return _mapping return _mapping
@classmethod @classmethod
def get_node_all_assets_id_mapping_from_cache(cls, org_id): def get_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id) cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
mapping = cache.get(cache_key) mapping = cache.get(cache_key)
return mapping return mapping
@classmethod @classmethod
def set_node_all_assets_id_mapping_to_cache(cls, org_id, mapping): def set_node_all_asset_ids_mapping_to_cache(cls, org_id, mapping):
cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id) cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, mapping, timeout=None) cache.set(cache_key, mapping, timeout=None)
@classmethod @classmethod
def expire_node_all_assets_id_mapping_from_cache(cls, org_id): def expire_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_assets_id_mapping(org_id) cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.delete(cache_key) cache.delete(cache_key)
@staticmethod @staticmethod
def _get_cache_key_for_node_all_assets_id_mapping(org_id): def _get_cache_key_for_node_all_asset_ids_mapping(org_id):
return 'ASSETS_ORG_NODE_ALL_ASSETS_ID_MAPPING_{}'.format(org_id) return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id)
@classmethod @classmethod
def generate_node_all_assets_id_mapping(cls, org_id): def generate_node_all_asset_ids_mapping(cls, org_id):
from .asset import Asset from .asset import Asset
t1 = time.time() t1 = time.time()
with tmp_to_org(org_id): with tmp_to_org(org_id):
nodes_id_key = Node.objects.annotate( node_ids_key = Node.objects.annotate(
char_id=output_as_string('id') char_id=output_as_string('id')
).values_list('char_id', 'key') ).values_list('char_id', 'key')
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢) # * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_assets_id = Asset.nodes.through.objects.all() \ nodes_asset_ids = Asset.nodes.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \ .annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \ .annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id') .values_list('char_node_id', 'char_asset_id')
node_id_ancestor_keys_mapping = { node_id_ancestor_keys_mapping = {
node_id: cls.get_node_ancestor_keys(node_key, with_self=True) node_id: cls.get_node_ancestor_keys(node_key, with_self=True)
for node_id, node_key in nodes_id_key for node_id, node_key in node_ids_key
} }
nodeid_assetsid_mapping = defaultdict(set) nodeid_assetsid_mapping = defaultdict(set)
for node_id, asset_id in nodes_assets_id: for node_id, asset_id in nodes_asset_ids:
nodeid_assetsid_mapping[node_id].add(asset_id) nodeid_assetsid_mapping[node_id].add(asset_id)
t2 = time.time() t2 = time.time()
mapping = defaultdict(set) mapping = defaultdict(set)
for node_id, node_key in nodes_id_key: for node_id, node_key in node_ids_key:
assets_id = nodeid_assetsid_mapping[node_id] asset_ids = nodeid_assetsid_mapping[node_id]
node_ancestor_keys = node_id_ancestor_keys_mapping[node_id] node_ancestor_keys = node_id_ancestor_keys_mapping[node_id]
for ancestor_key in node_ancestor_keys: for ancestor_key in node_ancestor_keys:
mapping[ancestor_key].update(assets_id) mapping[ancestor_key].update(asset_ids)
t3 = time.time() t3 = time.time()
logger.debug('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2-t1, t3-t2)) logger.debug('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2-t1, t3-t2))
@ -407,10 +407,10 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
return self.get_all_assets().valid() return self.get_all_assets().valid()
@classmethod @classmethod
def get_nodes_all_assets_ids_by_keys(cls, nodes_keys): def get_nodes_all_asset_ids_by_keys(cls, nodes_keys):
nodes = Node.objects.filter(key__in=nodes_keys) nodes = Node.objects.filter(key__in=nodes_keys)
assets_ids = cls.get_nodes_all_assets(*nodes).values_list('id', flat=True) asset_ids = cls.get_nodes_all_assets(*nodes).values_list('id', flat=True)
return assets_ids return asset_ids
@classmethod @classmethod
def get_nodes_all_assets(cls, *nodes): def get_nodes_all_assets(cls, *nodes):
@ -425,16 +425,16 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
node_ids.update(_ids) node_ids.update(_ids)
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct() return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
def get_all_assets_id(self): def get_all_asset_ids(self):
assets_id = self.get_all_assets_id_by_node_key(org_id=self.org_id, node_key=self.key) asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)
return set(assets_id) return set(asset_ids)
@classmethod @classmethod
def get_all_assets_id_by_node_key(cls, org_id, node_key): def get_all_asset_ids_by_node_key(cls, org_id, node_key):
org_id = str(org_id) org_id = str(org_id)
nodekey_assetsid_mapping = cls.get_node_all_assets_id_mapping(org_id) nodekey_assetsid_mapping = cls.get_node_all_asset_ids_mapping(org_id)
assets_id = nodekey_assetsid_mapping.get(node_key, []) asset_ids = nodekey_assetsid_mapping.get(node_key, [])
return set(assets_id) return set(asset_ids)
class SomeNodesMixin: class SomeNodesMixin:

View File

@ -198,10 +198,10 @@ class SystemUser(BaseUser):
def get_all_assets(self): def get_all_assets(self):
from assets.models import Node from assets.models import Node
nodes_keys = self.nodes.all().values_list('key', flat=True) nodes_keys = self.nodes.all().values_list('key', flat=True)
assets_ids = set(self.assets.all().values_list('id', flat=True)) asset_ids = set(self.assets.all().values_list('id', flat=True))
nodes_assets_ids = Node.get_nodes_all_assets_ids_by_keys(nodes_keys) nodes_asset_ids = Node.get_nodes_all_asset_ids_by_keys(nodes_keys)
assets_ids.update(nodes_assets_ids) asset_ids.update(nodes_asset_ids)
assets = Asset.objects.filter(id__in=assets_ids) assets = Asset.objects.filter(id__in=asset_ids)
return assets return assets
@classmethod @classmethod

View File

@ -82,13 +82,13 @@ def on_system_user_assets_change(instance, action, model, pk_set, **kwargs):
return return
logger.debug("System user assets change signal recv: {}".format(instance)) logger.debug("System user assets change signal recv: {}".format(instance))
if model == Asset: if model == Asset:
system_users_id = [instance.id] system_user_ids = [instance.id]
assets_id = pk_set asset_ids = pk_set
else: else:
system_users_id = pk_set system_user_ids = pk_set
assets_id = [instance.id] asset_ids = [instance.id]
for system_user_id in system_users_id: for system_user_id in system_user_ids:
push_system_user_to_assets.delay(system_user_id, assets_id) push_system_user_to_assets.delay(system_user_id, asset_ids)
@receiver(m2m_changed, sender=SystemUser.users.through) @receiver(m2m_changed, sender=SystemUser.users.through)

View File

@ -22,7 +22,7 @@ logger = get_logger(__file__)
def get_node_assets_mapping_for_memory_pub_sub(): def get_node_assets_mapping_for_memory_pub_sub():
return RedisPubSub('fm.node_all_assets_id_memory_mapping') return RedisPubSub('fm.node_all_asset_ids_memory_mapping')
class NodeAssetsMappingForMemoryPubSub(LazyObject): class NodeAssetsMappingForMemoryPubSub(LazyObject):
@ -42,7 +42,7 @@ def expire_node_assets_mapping_for_memory(org_id):
"Expire node assets id mapping from cache of org={}, pid={}" "Expire node assets id mapping from cache of org={}, pid={}"
"".format(org_id, os.getpid()) "".format(org_id, os.getpid())
) )
Node.expire_node_all_assets_id_mapping_from_cache(org_id) Node.expire_node_all_asset_ids_mapping_from_cache(org_id)
@receiver(post_save, sender=Node) @receiver(post_save, sender=Node)
@ -78,7 +78,7 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
if message["type"] != "message": if message["type"] != "message":
continue continue
org_id = message['data'].decode() org_id = message['data'].decode()
Node.expire_node_all_assets_id_mapping_from_memory(org_id) Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
logger.debug( logger.debug(
"Expire node assets id mapping from memory of org={}, pid={}" "Expire node assets id mapping from memory of org={}, pid={}"
"".format(str(org_id), os.getpid()) "".format(str(org_id), os.getpid())

View File

@ -233,18 +233,18 @@ def push_system_user_util(system_user, assets, task_name, username=None):
print(_("Hosts count: {}").format(len(_assets))) print(_("Hosts count: {}").format(len(_assets)))
id_asset_map = {_asset.id: _asset for _asset in _assets} id_asset_map = {_asset.id: _asset for _asset in _assets}
assets_id = id_asset_map.keys() asset_ids = id_asset_map.keys()
no_special_auth = [] no_special_auth = []
special_auth_set = set() special_auth_set = set()
auth_books = AuthBook.objects.filter(username__in=usernames, asset_id__in=assets_id) auth_books = AuthBook.objects.filter(username__in=usernames, asset_id__in=asset_ids)
for auth_book in auth_books: for auth_book in auth_books:
special_auth_set.add((auth_book.username, auth_book.asset_id)) special_auth_set.add((auth_book.username, auth_book.asset_id))
for _username in usernames: for _username in usernames:
no_special_assets = [] no_special_assets = []
for asset_id in assets_id: for asset_id in asset_ids:
if (_username, asset_id) not in special_auth_set: if (_username, asset_id) not in special_auth_set:
no_special_assets.append(id_asset_map[asset_id]) no_special_assets.append(id_asset_map[asset_id])
if no_special_assets: if no_special_assets:
@ -289,12 +289,12 @@ def push_system_user_a_asset_manual(system_user, asset, username=None):
@shared_task(queue="ansible") @shared_task(queue="ansible")
@tmp_to_root_org() @tmp_to_root_org()
def push_system_user_to_assets(system_user_id, assets_id, username=None): def push_system_user_to_assets(system_user_id, asset_ids, username=None):
""" """
推送系统用户到指定的若干资产上 推送系统用户到指定的若干资产上
""" """
system_user = SystemUser.objects.get(id=system_user_id) system_user = SystemUser.objects.get(id=system_user_id)
assets = get_objects(Asset, assets_id) assets = get_objects(Asset, asset_ids)
task_name = _("Push system users to assets: {}").format(system_user.name) task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name, username=username) return push_system_user_util(system_user, assets, task_name, username=username)

View File

@ -183,9 +183,9 @@ class UserConnectionTokenViewSet(RootOrgViewMixin, SerializerMixin2, GenericView
@staticmethod @staticmethod
def _get_asset_secret_detail(value, user, system_user): def _get_asset_secret_detail(value, user, system_user):
from assets.models import Asset from assets.models import Asset
from perms.utils.asset import get_asset_system_users_id_with_actions_by_user from perms.utils.asset import get_asset_system_user_ids_with_actions_by_user
asset = get_object_or_404(Asset, id=value.get('asset')) asset = get_object_or_404(Asset, id=value.get('asset'))
systemuserid_actions_mapper = get_asset_system_users_id_with_actions_by_user(user, asset) systemuserid_actions_mapper = get_asset_system_user_ids_with_actions_by_user(user, asset)
actions = systemuserid_actions_mapper.get(system_user.id, []) actions = systemuserid_actions_mapper.get(system_user.id, [])
gateway = None gateway = None
if asset and asset.domain and asset.domain.has_gateway(): if asset and asset.domain and asset.domain.has_gateway():

View File

@ -13,7 +13,7 @@ from rest_framework.viewsets import GenericViewSet
from common.permissions import IsValidUser from common.permissions import IsValidUser
from .http import HttpResponseTemporaryRedirect from .http import HttpResponseTemporaryRedirect
from .const import KEY_CACHE_RESOURCES_ID from .const import KEY_CACHE_RESOURCE_IDS
from .utils import get_logger from .utils import get_logger
from .mixins import CommonApiMixin from .mixins import CommonApiMixin
@ -93,7 +93,7 @@ class ResourcesIDCacheApi(APIView):
spm = str(uuid.uuid4()) spm = str(uuid.uuid4())
resources = request.data.get('resources') resources = request.data.get('resources')
if resources is not None: if resources is not None:
cache_key = KEY_CACHE_RESOURCES_ID.format(spm) cache_key = KEY_CACHE_RESOURCE_IDS.format(spm)
cache.set(cache_key, resources, 300) cache.set(cache_key, resources, 300)
return Response({'spm': spm}) return Response({'spm': spm})

View File

@ -7,7 +7,7 @@ create_success_msg = _("%(name)s was created successfully")
update_success_msg = _("%(name)s was updated successfully") update_success_msg = _("%(name)s was updated successfully")
FILE_END_GUARD = ">>> Content End <<<" FILE_END_GUARD = ">>> Content End <<<"
celery_task_pre_key = "CELERY_" celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCES_ID = "RESOURCES_ID_{}" KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable # AD User AccountDisable
# https://blog.csdn.net/bytxl/article/details/17763975 # https://blog.csdn.net/bytxl/article/details/17763975

View File

@ -108,11 +108,11 @@ class IDSpmFilter(filters.BaseFilterBackend):
spm = request.query_params.get('spm') spm = request.query_params.get('spm')
if not spm: if not spm:
return queryset return queryset
cache_key = const.KEY_CACHE_RESOURCES_ID.format(spm) cache_key = const.KEY_CACHE_RESOURCE_IDS.format(spm)
resources_id = cache.get(cache_key) resource_ids = cache.get(cache_key)
if resources_id is None or not isinstance(resources_id, list): if resource_ids is None or not isinstance(resource_ids, list):
return queryset return queryset
queryset = queryset.filter(id__in=resources_id) queryset = queryset.filter(id__in=resource_ids)
return queryset return queryset

View File

@ -91,14 +91,14 @@ def _remove_users(model, users, org):
f'{m2m_field_name}__org_id': org.id f'{m2m_field_name}__org_id': org.id
}) })
object_id_users_id_map = defaultdict(set) object_id_user_ids_map = defaultdict(set)
m2m_field_attr_name = f'{m2m_field_name}_id' m2m_field_attr_name = f'{m2m_field_name}_id'
for relation in relations: for relation in relations:
object_id = getattr(relation, m2m_field_attr_name) object_id = getattr(relation, m2m_field_attr_name)
object_id_users_id_map[object_id].add(relation.user_id) object_id_user_ids_map[object_id].add(relation.user_id)
objects = model.objects.filter(id__in=object_id_users_id_map.keys()) objects = model.objects.filter(id__in=object_id_user_ids_map.keys())
send_m2m_change_signal = partial( send_m2m_change_signal = partial(
m2m_changed.send, m2m_changed.send,
sender=m2m_model, reverse=reverse, model=User, using=model.objects.db sender=m2m_model, reverse=reverse, model=User, using=model.objects.db
@ -107,7 +107,7 @@ def _remove_users(model, users, org):
for obj in objects: for obj in objects:
send_m2m_change_signal( send_m2m_change_signal(
instance=obj, instance=obj,
pk_set=object_id_users_id_map[obj.id], pk_set=object_id_user_ids_map[obj.id],
action=PRE_REMOVE action=PRE_REMOVE
) )
@ -116,7 +116,7 @@ def _remove_users(model, users, org):
for obj in objects: for obj in objects:
send_m2m_change_signal( send_m2m_change_signal(
instance=obj, instance=obj,
pk_set=object_id_users_id_map[obj.id], pk_set=object_id_user_ids_map[obj.id],
action=POST_REMOVE action=POST_REMOVE
) )

View File

@ -12,7 +12,7 @@ from orgs.utils import tmp_to_root_org
from applications.models import Application from applications.models import Application
from perms.utils.application.permission import ( from perms.utils.application.permission import (
has_application_system_permission, has_application_system_permission,
get_application_system_users_id get_application_system_user_ids
) )
from perms.api.asset.user_permission.mixin import RoleAdminMixin, RoleUserMixin from perms.api.asset.user_permission.mixin import RoleAdminMixin, RoleUserMixin
from common.permissions import IsOrgAdminOrAppUser from common.permissions import IsOrgAdminOrAppUser
@ -32,14 +32,14 @@ class GrantedApplicationSystemUsersMixin(ListAPIView):
only_fields = serializers.ApplicationSystemUserSerializer.Meta.only_fields only_fields = serializers.ApplicationSystemUserSerializer.Meta.only_fields
user: None user: None
def get_application_system_users_id(self, application): def get_application_system_user_ids(self, application):
return get_application_system_users_id(self.user, application) return get_application_system_user_ids(self.user, application)
def get_queryset(self): def get_queryset(self):
application_id = self.kwargs.get('application_id') application_id = self.kwargs.get('application_id')
application = get_object_or_404(Application, id=application_id) application = get_object_or_404(Application, id=application_id)
system_users_id = self.get_application_system_users_id(application) system_user_ids = self.get_application_system_user_ids(application)
system_users = SystemUser.objects.filter(id__in=system_users_id)\ system_users = SystemUser.objects.filter(id__in=system_user_ids)\
.only(*self.only_fields).order_by('priority') .only(*self.only_fields).order_by('priority')
return system_users return system_users

View File

@ -12,7 +12,7 @@ from perms.models import AssetPermission
from assets.models import Asset, Node from assets.models import Asset, Node
from perms.api.asset import user_permission as uapi from perms.api.asset import user_permission as uapi
from perms import serializers from perms import serializers
from perms.utils.asset.permission import get_asset_system_users_id_with_actions_by_group from perms.utils.asset.permission import get_asset_system_user_ids_with_actions_by_group
from assets.api.mixin import SerializeToTreeNodeMixin from assets.api.mixin import SerializeToTreeNodeMixin
from users.models import UserGroup from users.models import UserGroup
@ -41,12 +41,12 @@ class UserGroupGrantedAssetsApi(ListAPIView):
def get_queryset(self): def get_queryset(self):
user_group_id = self.kwargs.get('pk', '') user_group_id = self.kwargs.get('pk', '')
asset_perms_id = list(AssetPermission.objects.valid().filter( asset_perm_ids = list(AssetPermission.objects.valid().filter(
user_groups__id=user_group_id user_groups__id=user_group_id
).distinct().values_list('id', flat=True)) ).distinct().values_list('id', flat=True))
granted_node_keys = Node.objects.filter( granted_node_keys = Node.objects.filter(
granted_by_permissions__id__in=asset_perms_id, granted_by_permissions__id__in=asset_perm_ids,
).distinct().values_list('key', flat=True) ).distinct().values_list('key', flat=True)
granted_q = Q() granted_q = Q()
@ -54,7 +54,7 @@ class UserGroupGrantedAssetsApi(ListAPIView):
granted_q |= Q(nodes__key__startswith=f'{_key}:') granted_q |= Q(nodes__key__startswith=f'{_key}:')
granted_q |= Q(nodes__key=_key) granted_q |= Q(nodes__key=_key)
granted_q |= Q(granted_by_permissions__id__in=asset_perms_id) granted_q |= Q(granted_by_permissions__id__in=asset_perm_ids)
assets = Asset.objects.filter( assets = Asset.objects.filter(
granted_q granted_q
@ -89,12 +89,12 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
) )
return assets return assets
else: else:
asset_perms_id = list(AssetPermission.objects.valid().filter( asset_perm_ids = list(AssetPermission.objects.valid().filter(
user_groups__id=user_group_id user_groups__id=user_group_id
).distinct().values_list('id', flat=True)) ).distinct().values_list('id', flat=True))
granted_node_keys = Node.objects.filter( granted_node_keys = Node.objects.filter(
granted_by_permissions__id__in=asset_perms_id, granted_by_permissions__id__in=asset_perm_ids,
key__startswith=f'{node.key}:' key__startswith=f'{node.key}:'
).distinct().values_list('key', flat=True) ).distinct().values_list('key', flat=True)
@ -104,7 +104,7 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
granted_node_q |= Q(nodes__key=_key) granted_node_q |= Q(nodes__key=_key)
granted_asset_q = ( granted_asset_q = (
Q(granted_by_permissions__id__in=asset_perms_id) & Q(granted_by_permissions__id__in=asset_perm_ids) &
( (
Q(nodes__key__startswith=f'{node.key}:') | Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key) Q(nodes__key=node.key)
@ -148,16 +148,16 @@ class UserGroupGrantedNodeChildrenAsTreeApi(SerializeToTreeNodeMixin, ListAPIVie
group_id = self.kwargs.get('pk') group_id = self.kwargs.get('pk')
node_key = self.request.query_params.get('key', None) node_key = self.request.query_params.get('key', None)
asset_perms_id = list(AssetPermission.objects.valid().filter( asset_perm_ids = list(AssetPermission.objects.valid().filter(
user_groups__id=group_id user_groups__id=group_id
).distinct().values_list('id', flat=True)) ).distinct().values_list('id', flat=True))
granted_keys = Node.objects.filter( granted_keys = Node.objects.filter(
granted_by_permissions__id__in=asset_perms_id granted_by_permissions__id__in=asset_perm_ids
).values_list('key', flat=True) ).values_list('key', flat=True)
asset_granted_keys = Node.objects.filter( asset_granted_keys = Node.objects.filter(
assets__granted_by_permissions__id__in=asset_perms_id assets__granted_by_permissions__id__in=asset_perm_ids
).values_list('key', flat=True) ).values_list('key', flat=True)
if node_key is None: if node_key is None:
@ -188,5 +188,5 @@ class UserGroupGrantedNodeChildrenAsTreeApi(SerializeToTreeNodeMixin, ListAPIVie
class UserGroupGrantedAssetSystemUsersApi(UserGroupMixin, uapi.UserGrantedAssetSystemUsersForAdminApi): class UserGroupGrantedAssetSystemUsersApi(UserGroupMixin, uapi.UserGrantedAssetSystemUsersForAdminApi):
def get_asset_system_users_id_with_actions(self, asset): def get_asset_system_user_ids_with_actions(self, asset):
return get_asset_system_users_id_with_actions_by_group(self.group, asset) return get_asset_system_user_ids_with_actions_by_group(self.group, asset)

View File

@ -10,7 +10,7 @@ from rest_framework.generics import (
) )
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org
from perms.utils.asset.permission import get_asset_system_users_id_with_actions_by_user from perms.utils.asset.permission import get_asset_system_user_ids_with_actions_by_user
from common.permissions import IsOrgAdminOrAppUser, IsOrgAdmin, IsValidUser from common.permissions import IsOrgAdminOrAppUser, IsOrgAdmin, IsValidUser
from common.utils import get_logger, lazyproperty from common.utils import get_logger, lazyproperty
@ -53,7 +53,7 @@ class GetUserAssetPermissionActionsApi(RetrieveAPIView):
asset = get_object_or_404(Asset, id=asset_id) asset = get_object_or_404(Asset, id=asset_id)
system_user = get_object_or_404(SystemUser, id=system_id) system_user = get_object_or_404(SystemUser, id=system_id)
system_users_actions = get_asset_system_users_id_with_actions_by_user(self.get_user(), asset) system_users_actions = get_asset_system_user_ids_with_actions_by_user(self.get_user(), asset)
actions = system_users_actions.get(system_user.id) actions = system_users_actions.get(system_user.id)
return {"actions": actions} return {"actions": actions}
@ -84,7 +84,7 @@ class ValidateUserAssetPermissionApi(APIView):
asset = get_object_or_404(Asset, id=asset_id) asset = get_object_or_404(Asset, id=asset_id)
system_user = get_object_or_404(SystemUser, id=system_id) system_user = get_object_or_404(SystemUser, id=system_id)
system_users_actions = get_asset_system_users_id_with_actions_by_user(self.get_user(), asset) system_users_actions = get_asset_system_user_ids_with_actions_by_user(self.get_user(), asset)
actions = system_users_actions.get(system_user.id) actions = system_users_actions.get(system_user.id)
if actions is None: if actions is None:
return Response({'msg': False}, status=403) return Response({'msg': False}, status=403)
@ -111,15 +111,15 @@ class UserGrantedAssetSystemUsersForAdminApi(ListAPIView):
user_id = self.kwargs.get('pk') user_id = self.kwargs.get('pk')
return User.objects.get(id=user_id) return User.objects.get(id=user_id)
def get_asset_system_users_id_with_actions(self, asset): def get_asset_system_user_ids_with_actions(self, asset):
return get_asset_system_users_id_with_actions_by_user(self.user, asset) return get_asset_system_user_ids_with_actions_by_user(self.user, asset)
def get_queryset(self): def get_queryset(self):
asset_id = self.kwargs.get('asset_id') asset_id = self.kwargs.get('asset_id')
asset = get_object_or_404(Asset, id=asset_id) asset = get_object_or_404(Asset, id=asset_id)
system_users_with_actions = self.get_asset_system_users_id_with_actions(asset) system_users_with_actions = self.get_asset_system_user_ids_with_actions(asset)
system_users_id = system_users_with_actions.keys() system_user_ids = system_users_with_actions.keys()
system_users = SystemUser.objects.filter(id__in=system_users_id)\ system_users = SystemUser.objects.filter(id__in=system_user_ids)\
.only(*self.serializer_class.Meta.only_fields) \ .only(*self.serializer_class.Meta.only_fields) \
.order_by('priority') .order_by('priority')
system_users = list(system_users) system_users = list(system_users)

View File

@ -52,8 +52,8 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
data.extend(self.serialize_assets(favorite_assets)) data.extend(self.serialize_assets(favorite_assets))
@timeit @timeit
def add_node_filtered_by_system_user(self, data: list, user, asset_perms_id): def add_node_filtered_by_system_user(self, data: list, user, asset_perm_ids):
utils = UserGrantedTreeBuildUtils(user, asset_perms_id) utils = UserGrantedTreeBuildUtils(user, asset_perm_ids)
nodes = utils.get_whole_tree_nodes() nodes = utils.get_whole_tree_nodes()
data.extend(self.serialize_nodes(nodes, with_asset_amount=True)) data.extend(self.serialize_nodes(nodes, with_asset_amount=True))
@ -77,23 +77,23 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
user = request.user user = request.user
data = [] data = []
asset_perms_id = get_user_all_asset_perm_ids(user) asset_perm_ids = get_user_all_asset_perm_ids(user)
system_user_id = request.query_params.get('system_user') system_user_id = request.query_params.get('system_user')
if system_user_id: if system_user_id:
asset_perms_id = list(AssetPermission.objects.valid().filter( asset_perm_ids = list(AssetPermission.objects.valid().filter(
id__in=asset_perms_id, system_users__id=system_user_id, actions__gt=0 id__in=asset_perm_ids, system_users__id=system_user_id, actions__gt=0
).values_list('id', flat=True).distinct()) ).values_list('id', flat=True).distinct())
nodes_query_utils = UserGrantedNodesQueryUtils(user, asset_perms_id) nodes_query_utils = UserGrantedNodesQueryUtils(user, asset_perm_ids)
assets_query_utils = UserGrantedAssetsQueryUtils(user, asset_perms_id) assets_query_utils = UserGrantedAssetsQueryUtils(user, asset_perm_ids)
self.add_ungrouped_resource(data, nodes_query_utils, assets_query_utils) self.add_ungrouped_resource(data, nodes_query_utils, assets_query_utils)
self.add_favorite_resource(data, nodes_query_utils, assets_query_utils) self.add_favorite_resource(data, nodes_query_utils, assets_query_utils)
if system_user_id: if system_user_id:
# 有系统用户筛选的需要重新计算树结构 # 有系统用户筛选的需要重新计算树结构
self.add_node_filtered_by_system_user(data, user, asset_perms_id) self.add_node_filtered_by_system_user(data, user, asset_perm_ids)
else: else:
all_nodes = nodes_query_utils.get_whole_tree_nodes(with_special=False) all_nodes = nodes_query_utils.get_whole_tree_nodes(with_special=False)
data.extend(self.serialize_nodes(all_nodes, with_asset_amount=True)) data.extend(self.serialize_nodes(all_nodes, with_asset_amount=True))

View File

@ -16,9 +16,9 @@ class SystemUserPermission(generics.ListAPIView):
def get_queryset(self): def get_queryset(self):
user = self.request.user user = self.request.user
asset_perms_id = get_user_all_asset_perm_ids(user) asset_perm_ids = get_user_all_asset_perm_ids(user)
queryset = SystemUser.objects.filter( queryset = SystemUser.objects.filter(
granted_by_permissions__id__in=asset_perms_id granted_by_permissions__id__in=asset_perm_ids
).distinct() ).distinct()
return queryset return queryset

View File

@ -65,9 +65,9 @@ class ApplicationPermission(BasePermission):
return self.system_users.count() return self.system_users.count()
def get_all_users(self): def get_all_users(self):
users_id = self.users.all().values_list('id', flat=True) user_ids = self.users.all().values_list('id', flat=True)
user_groups_id = self.user_groups.all().values_list('id', flat=True) user_group_ids = self.user_groups.all().values_list('id', flat=True)
users = User.objects.filter( users = User.objects.filter(
Q(id__in=users_id) | Q(groups__id__in=user_groups_id) Q(id__in=user_ids) | Q(groups__id__in=user_group_ids)
) )
return users return users

View File

@ -137,10 +137,10 @@ class AssetPermission(BasePermission):
def get_all_assets(self): def get_all_assets(self):
from assets.models import Node from assets.models import Node
nodes_keys = self.nodes.all().values_list('key', flat=True) nodes_keys = self.nodes.all().values_list('key', flat=True)
assets_ids = set(self.assets.all().values_list('id', flat=True)) asset_ids = set(self.assets.all().values_list('id', flat=True))
nodes_assets_ids = Node.get_nodes_all_assets_ids_by_keys(nodes_keys) nodes_asset_ids = Node.get_nodes_all_asset_ids_by_keys(nodes_keys)
assets_ids.update(nodes_assets_ids) asset_ids.update(nodes_asset_ids)
assets = Asset.objects.filter(id__in=assets_ids) assets = Asset.objects.filter(id__in=asset_ids)
return assets return assets

View File

@ -99,14 +99,14 @@ class BasePermission(OrgModelMixin):
def get_all_users(self): def get_all_users(self):
from users.models import User from users.models import User
users_id = self.users.all().values_list('id', flat=True) user_ids = self.users.all().values_list('id', flat=True)
groups_id = self.user_groups.all().values_list('id', flat=True) group_ids = self.user_groups.all().values_list('id', flat=True)
users_id = list(users_id) user_ids = list(user_ids)
groups_id = list(groups_id) group_ids = list(group_ids)
qs1 = User.objects.filter(id__in=users_id).distinct() qs1 = User.objects.filter(id__in=user_ids).distinct()
qs2 = User.objects.filter(groups__id__in=groups_id).distinct() qs2 = User.objects.filter(groups__id__in=group_ids).distinct()
qs = UnionQuerySet(qs1, qs2) qs = UnionQuerySet(qs1, qs2)
return qs return qs

View File

@ -27,15 +27,15 @@ def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs):
if not reverse: if not reverse:
# 一个用户添加了多个用户组 # 一个用户添加了多个用户组
users_id = [instance.id] user_ids = [instance.id]
system_users = SystemUser.objects.filter(groups__id__in=pk_set).distinct() system_users = SystemUser.objects.filter(groups__id__in=pk_set).distinct()
else: else:
# 一个用户组添加了多个用户 # 一个用户组添加了多个用户
users_id = pk_set user_ids = pk_set
system_users = SystemUser.objects.filter(groups__id=instance.pk).distinct() system_users = SystemUser.objects.filter(groups__id=instance.pk).distinct()
for system_user in system_users: for system_user in system_users:
system_user.users.add(*users_id) system_user.users.add(*user_ids)
@receiver(m2m_changed, sender=AssetPermission.nodes.through) @receiver(m2m_changed, sender=AssetPermission.nodes.through)
@ -139,17 +139,17 @@ def on_application_permission_system_users_changed(sender, instance: Application
logger.debug("Application permission system_users change signal received") logger.debug("Application permission system_users change signal received")
attrs = instance.applications.all().values_list('attrs', flat=True) attrs = instance.applications.all().values_list('attrs', flat=True)
assets_id = [attr['asset'] for attr in attrs if attr.get('asset')] asset_ids = [attr['asset'] for attr in attrs if attr.get('asset')]
if not assets_id: if not asset_ids:
return return
for system_user in system_users: for system_user in system_users:
system_user.assets.add(*assets_id) system_user.assets.add(*asset_ids)
if system_user.username_same_with_user: if system_user.username_same_with_user:
users_id = instance.users.all().values_list('id', flat=True) user_ids = instance.users.all().values_list('id', flat=True)
groups_id = instance.user_groups.all().values_list('id', flat=True) group_ids = instance.user_groups.all().values_list('id', flat=True)
system_user.groups.add(*groups_id) system_user.groups.add(*group_ids)
system_user.users.add(*users_id) system_user.users.add(*user_ids)
@receiver(m2m_changed, sender=ApplicationPermission.users.through) @receiver(m2m_changed, sender=ApplicationPermission.users.through)
@ -164,12 +164,12 @@ def on_application_permission_users_changed(sender, instance, action, reverse, p
return return
logger.debug("Application permission users change signal received") logger.debug("Application permission users change signal received")
users_id = User.objects.filter(pk__in=pk_set).values_list('id', flat=True) user_ids = User.objects.filter(pk__in=pk_set).values_list('id', flat=True)
system_users = instance.system_users.all() system_users = instance.system_users.all()
for system_user in system_users: for system_user in system_users:
if system_user.username_same_with_user: if system_user.username_same_with_user:
system_user.users.add(*users_id) system_user.users.add(*user_ids)
@receiver(m2m_changed, sender=ApplicationPermission.user_groups.through) @receiver(m2m_changed, sender=ApplicationPermission.user_groups.through)
@ -182,12 +182,12 @@ def on_application_permission_user_groups_changed(sender, instance, action, reve
return return
logger.debug("Application permission user groups change signal received") logger.debug("Application permission user groups change signal received")
groups_id = UserGroup.objects.filter(pk__in=pk_set).values_list('id', flat=True) group_ids = UserGroup.objects.filter(pk__in=pk_set).values_list('id', flat=True)
system_users = instance.system_users.all() system_users = instance.system_users.all()
for system_user in system_users: for system_user in system_users:
if system_user.username_same_with_user: if system_user.username_same_with_user:
system_user.groups.add(*groups_id) system_user.groups.add(*group_ids)
@receiver(m2m_changed, sender=ApplicationPermission.applications.through) @receiver(m2m_changed, sender=ApplicationPermission.applications.through)
@ -202,11 +202,11 @@ def on_application_permission_applications_changed(sender, instance, action, rev
return return
attrs = Application.objects.filter(id__in=pk_set).values_list('attrs', flat=True) attrs = Application.objects.filter(id__in=pk_set).values_list('attrs', flat=True)
assets_id = [attr['asset'] for attr in attrs if attr.get('asset')] asset_ids = [attr['asset'] for attr in attrs if attr.get('asset')]
if not assets_id: if not asset_ids:
return return
system_users = instance.system_users.all() system_users = instance.system_users.all()
for system_user in system_users: for system_user in system_users:
system_user.assets.add(*assets_id) system_user.assets.add(*asset_ids)

View File

@ -6,7 +6,7 @@ from perms.models import ApplicationPermission
logger = get_logger(__file__) logger = get_logger(__file__)
def get_application_system_users_id(user, application): def get_application_system_user_ids(user, application):
queryset = ApplicationPermission.objects.valid()\ queryset = ApplicationPermission.objects.valid()\
.filter( .filter(
Q(users=user) | Q(user_groups__users=user), Q(users=user) | Q(user_groups__users=user),
@ -16,5 +16,5 @@ def get_application_system_users_id(user, application):
def has_application_system_permission(user, application, system_user): def has_application_system_permission(user, application, system_user):
system_users_id = get_application_system_users_id(user, application) system_user_ids = get_application_system_user_ids(user, application)
return system_user.id in system_users_id return system_user.id in system_user_ids

View File

@ -3,7 +3,7 @@ from perms.models import ApplicationPermission
from applications.models import Application from applications.models import Application
def get_user_all_applicationpermissions_id(user): def get_user_all_applicationpermission_ids(user):
application_perm_ids = ApplicationPermission.objects.valid().filter( application_perm_ids = ApplicationPermission.objects.valid().filter(
Q(users=user) | Q(user_groups__users=user) Q(users=user) | Q(user_groups__users=user)
).distinct().values_list('id', flat=True) ).distinct().values_list('id', flat=True)
@ -11,8 +11,8 @@ def get_user_all_applicationpermissions_id(user):
def get_user_granted_all_applications(user): def get_user_granted_all_applications(user):
application_perms_id = get_user_all_applicationpermissions_id(user) application_perm_ids = get_user_all_applicationpermission_ids(user)
applications = Application.objects.filter( applications = Application.objects.filter(
granted_by_permissions__id__in=application_perms_id granted_by_permissions__id__in=application_perm_ids
).distinct() ).distinct()
return applications return applications

View File

@ -5,13 +5,12 @@ from django.db.models import Q
from common.utils import get_logger from common.utils import get_logger
from perms.models import AssetPermission from perms.models import AssetPermission
from perms.hands import Asset, User, UserGroup, SystemUser from perms.hands import Asset, User, UserGroup, SystemUser
from perms.models.base import BasePermissionQuerySet
from perms.utils.asset.user_permission import get_user_all_asset_perm_ids from perms.utils.asset.user_permission import get_user_all_asset_perm_ids
logger = get_logger(__file__) logger = get_logger(__file__)
def get_asset_system_users_id_with_actions(asset_perm_ids, asset: Asset): def get_asset_system_user_ids_with_actions(asset_perm_ids, asset: Asset):
nodes = asset.get_nodes() nodes = asset.get_nodes()
node_keys = set() node_keys = set()
for node in nodes: for node in nodes:
@ -34,21 +33,21 @@ def get_asset_system_users_id_with_actions(asset_perm_ids, asset: Asset):
return system_users_actions return system_users_actions
def get_asset_system_users_id_with_actions_by_user(user: User, asset: Asset): def get_asset_system_user_ids_with_actions_by_user(user: User, asset: Asset):
asset_perm_ids = get_user_all_asset_perm_ids(user) asset_perm_ids = get_user_all_asset_perm_ids(user)
return get_asset_system_users_id_with_actions(asset_perm_ids, asset) return get_asset_system_user_ids_with_actions(asset_perm_ids, asset)
def has_asset_system_permission(user: User, asset: Asset, system_user: SystemUser): def has_asset_system_permission(user: User, asset: Asset, system_user: SystemUser):
systemuser_actions_mapper = get_asset_system_users_id_with_actions_by_user(user, asset) systemuser_actions_mapper = get_asset_system_user_ids_with_actions_by_user(user, asset)
actions = systemuser_actions_mapper.get(system_user.id, []) actions = systemuser_actions_mapper.get(system_user.id, [])
if actions: if actions:
return True return True
return False return False
def get_asset_system_users_id_with_actions_by_group(group: UserGroup, asset: Asset): def get_asset_system_user_ids_with_actions_by_group(group: UserGroup, asset: Asset):
asset_perm_ids = AssetPermission.objects.filter( asset_perm_ids = AssetPermission.objects.filter(
user_groups=group user_groups=group
).valid().values_list('id', flat=True).distinct() ).valid().values_list('id', flat=True).distinct()
return get_asset_system_users_id_with_actions(asset_perm_ids, asset) return get_asset_system_user_ids_with_actions(asset_perm_ids, asset)

View File

@ -70,43 +70,43 @@ class UserGrantedTreeRefreshController:
return {org_id.decode() for org_id in org_ids} return {org_id.decode() for org_id in org_ids}
def set_all_orgs_as_builed(self): def set_all_orgs_as_builed(self):
self.client.sadd(self.key, *self.orgs_id) self.client.sadd(self.key, *self.org_ids)
def have_need_refresh_orgs(self): def have_need_refresh_orgs(self):
builded_org_ids = self.client.smembers(self.key) builded_org_ids = self.client.smembers(self.key)
builded_org_ids = {org_id.decode() for org_id in builded_org_ids} builded_org_ids = {org_id.decode() for org_id in builded_org_ids}
have = self.orgs_id - builded_org_ids have = self.org_ids - builded_org_ids
return have return have
def get_need_refresh_orgs_and_fill_up(self): def get_need_refresh_orgs_and_fill_up(self):
orgs_id = self.orgs_id org_ids = self.org_ids
with self.client.pipeline() as p: with self.client.pipeline() as p:
p.smembers(self.key) p.smembers(self.key)
p.sadd(self.key, *orgs_id) p.sadd(self.key, *org_ids)
ret = p.execute() ret = p.execute()
builded_orgs_id = {org_id.decode() for org_id in ret[0]} builded_org_ids = {org_id.decode() for org_id in ret[0]}
ids = orgs_id - builded_orgs_id ids = org_ids - builded_org_ids
orgs = {*Organization.objects.filter(id__in=ids)} orgs = {*Organization.objects.filter(id__in=ids)}
logger.info(f'Need rebuild orgs are {orgs}, builed orgs are {ret[0]}, all orgs are {orgs_id}') logger.info(f'Need rebuild orgs are {orgs}, builed orgs are {ret[0]}, all orgs are {org_ids}')
return orgs return orgs
@classmethod @classmethod
@on_transaction_commit @on_transaction_commit
def remove_builed_orgs_from_users(cls, orgs_id, users_id): def remove_builed_orgs_from_users(cls, org_ids, user_ids):
client = cls.get_redis_client() client = cls.get_redis_client()
org_ids = [str(org_id) for org_id in orgs_id] org_ids = [str(org_id) for org_id in org_ids]
with client.pipeline() as p: with client.pipeline() as p:
for user_id in users_id: for user_id in user_ids:
key = cls.key_template.format(user_id=user_id) key = cls.key_template.format(user_id=user_id)
p.srem(key, *org_ids) p.srem(key, *org_ids)
p.execute() p.execute()
logger.info(f'Remove orgs from users builded tree: users:{users_id} orgs:{orgs_id}') logger.info(f'Remove orgs from users builded tree: users:{user_ids} orgs:{org_ids}')
@classmethod @classmethod
def add_need_refresh_orgs_for_users(cls, orgs_id, users_id): def add_need_refresh_orgs_for_users(cls, org_ids, user_ids):
cls.remove_builed_orgs_from_users(orgs_id, users_id) cls.remove_builed_orgs_from_users(org_ids, user_ids)
@classmethod @classmethod
@ensure_in_real_or_default_org @ensure_in_real_or_default_org
@ -127,15 +127,15 @@ class UserGrantedTreeRefreshController:
ancestor_id = PermNode.objects.filter(key__in=ancestor_node_keys).values_list('id', flat=True) ancestor_id = PermNode.objects.filter(key__in=ancestor_node_keys).values_list('id', flat=True)
node_ids.update(ancestor_id) node_ids.update(ancestor_id)
assets_related_perms_id = AssetPermission.nodes.through.objects.filter( assets_related_perm_ids = AssetPermission.nodes.through.objects.filter(
node_id__in=node_ids node_id__in=node_ids
).values_list('assetpermission_id', flat=True) ).values_list('assetpermission_id', flat=True)
asset_perm_ids.update(assets_related_perms_id) asset_perm_ids.update(assets_related_perm_ids)
nodes_related_perms_id = AssetPermission.assets.through.objects.filter( nodes_related_perm_ids = AssetPermission.assets.through.objects.filter(
asset_id__in=asset_ids asset_id__in=asset_ids
).values_list('assetpermission_id', flat=True) ).values_list('assetpermission_id', flat=True)
asset_perm_ids.update(nodes_related_perms_id) asset_perm_ids.update(nodes_related_perm_ids)
cls.add_need_refresh_by_asset_perm_ids(asset_perm_ids) cls.add_need_refresh_by_asset_perm_ids(asset_perm_ids)
@ -173,7 +173,7 @@ class UserGrantedTreeRefreshController:
) )
@lazyproperty @lazyproperty
def orgs_id(self): def org_ids(self):
ret = {str(org.id) for org in self.orgs} ret = {str(org.id) for org in self.orgs}
return ret return ret
@ -187,7 +187,7 @@ class UserGrantedTreeRefreshController:
user = self.user user = self.user
with tmp_to_root_org(): with tmp_to_root_org():
UserAssetGrantedTreeNodeRelation.objects.filter(user=user).exclude(org_id__in=self.orgs_id).delete() UserAssetGrantedTreeNodeRelation.objects.filter(user=user).exclude(org_id__in=self.org_ids).delete()
if force or self.have_need_refresh_orgs(): if force or self.have_need_refresh_orgs():
with UserGrantedTreeRebuildLock(user_id=user.id): with UserGrantedTreeRebuildLock(user_id=user.id):
@ -295,10 +295,10 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
# 查询授权资产关联的节点设置 # 查询授权资产关联的节点设置
def process_direct_granted_assets(): def process_direct_granted_assets():
# 查询直接授权资产 # 查询直接授权资产
nodes_id = {node_id_str for node_id_str, _ in self.direct_granted_asset_id_node_id_str_pairs} node_ids = {node_id_str for node_id_str, _ in self.direct_granted_asset_id_node_id_str_pairs}
# 查询授权资产关联的节点设置 2.80 # 查询授权资产关联的节点设置 2.80
granted_asset_nodes = PermNode.objects.filter( granted_asset_nodes = PermNode.objects.filter(
id__in=nodes_id id__in=node_ids
).distinct().only(*node_only_fields) ).distinct().only(*node_only_fields)
granted_asset_nodes = list(granted_asset_nodes) granted_asset_nodes = list(granted_asset_nodes)
@ -350,11 +350,11 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
UserAssetGrantedTreeNodeRelation.objects.bulk_create(to_create) UserAssetGrantedTreeNodeRelation.objects.bulk_create(to_create)
@timeit @timeit
def _fill_direct_granted_node_assets_id_from_mem(self, nodes_key, mapper): def _fill_direct_granted_node_asset_ids_from_mem(self, nodes_key, mapper):
org_id = current_org.id org_id = current_org.id
for key in nodes_key: for key in nodes_key:
assets_id = PermNode.get_all_assets_id_by_node_key(org_id, key) asset_ids = PermNode.get_all_asset_ids_by_node_key(org_id, key)
mapper[key].update(assets_id) mapper[key].update(asset_ids)
@lazyproperty @lazyproperty
def direct_granted_asset_id_node_id_str_pairs(self): def direct_granted_asset_id_node_id_str_pairs(self):
@ -379,7 +379,7 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
node = nodes[0] node = nodes[0]
if node.node_from == NodeFrom.granted and node.key.isdigit(): if node.node_from == NodeFrom.granted and node.key.isdigit():
with tmp_to_org(node.org): with tmp_to_org(node.org):
node.granted_assets_amount = len(node.get_all_assets_id()) node.granted_assets_amount = len(node.get_all_asset_ids())
return return
direct_granted_nodes_key = [] direct_granted_nodes_key = []
@ -392,7 +392,7 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
# 授权的节点和直接资产的映射 # 授权的节点和直接资产的映射
nodekey_assetsid_mapper = defaultdict(set) nodekey_assetsid_mapper = defaultdict(set)
# 直接授权的节点,资产从完整树过来 # 直接授权的节点,资产从完整树过来
self._fill_direct_granted_node_assets_id_from_mem( self._fill_direct_granted_node_asset_ids_from_mem(
direct_granted_nodes_key, nodekey_assetsid_mapper direct_granted_nodes_key, nodekey_assetsid_mapper
) )

View File

@ -86,13 +86,13 @@ class StatusViewSet(viewsets.ModelViewSet):
return Response(serializer.data, status=201) return Response(serializer.data, status=201)
def handle_sessions(self): def handle_sessions(self):
sessions_id = self.request.data.get('sessions', []) session_ids = self.request.data.get('sessions', [])
# guacamole 上报的 session 是字符串 # guacamole 上报的 session 是字符串
# "[53cd3e47-210f-41d8-b3c6-a184f3, 53cd3e47-210f-41d8-b3c6-a184f4]" # "[53cd3e47-210f-41d8-b3c6-a184f3, 53cd3e47-210f-41d8-b3c6-a184f4]"
if isinstance(sessions_id, str): if isinstance(session_ids, str):
sessions_id = sessions_id[1:-1].split(',') session_ids = session_ids[1:-1].split(',')
sessions_id = [sid.strip() for sid in sessions_id if sid.strip()] session_ids = [sid.strip() for sid in session_ids if sid.strip()]
Session.set_sessions_active(sessions_id) Session.set_sessions_active(session_ids)
def get_queryset(self): def get_queryset(self):
terminal_id = self.kwargs.get("terminal", None) terminal_id = self.kwargs.get("terminal", None)

View File

@ -137,8 +137,8 @@ class Session(OrgModelMixin):
return name, None return name, None
@classmethod @classmethod
def set_sessions_active(cls, sessions_id): def set_sessions_active(cls, session_ids):
data = {cls.ACTIVE_CACHE_KEY_PREFIX.format(i): i for i in sessions_id} data = {cls.ACTIVE_CACHE_KEY_PREFIX.format(i): i for i in session_ids}
cache.set_many(data, timeout=5*60) cache.set_many(data, timeout=5*60)
@classmethod @classmethod

View File

@ -26,11 +26,11 @@ class Handler(BaseHandler):
def _construct_meta_display_of_approve(self): def _construct_meta_display_of_approve(self):
meta_display_fields = ['approve_applications_display', 'approve_system_users_display'] meta_display_fields = ['approve_applications_display', 'approve_system_users_display']
approve_applications_id = self.ticket.meta.get('approve_applications', []) approve_application_ids = self.ticket.meta.get('approve_applications', [])
approve_system_users_id = self.ticket.meta.get('approve_system_users', []) approve_system_user_ids = self.ticket.meta.get('approve_system_users', [])
with tmp_to_org(self.ticket.org_id): with tmp_to_org(self.ticket.org_id):
approve_applications = Application.objects.filter(id__in=approve_applications_id) approve_applications = Application.objects.filter(id__in=approve_application_ids)
system_users = SystemUser.objects.filter(id__in=approve_system_users_id) system_users = SystemUser.objects.filter(id__in=approve_system_user_ids)
approve_applications_display = [str(application) for application in approve_applications] approve_applications_display = [str(application) for application in approve_applications]
approve_system_users_display = [str(system_user) for system_user in system_users] approve_system_users_display = [str(system_user) for system_user in system_users]
meta_display_values = [approve_applications_display, approve_system_users_display] meta_display_values = [approve_applications_display, approve_system_users_display]
@ -89,8 +89,8 @@ class Handler(BaseHandler):
apply_category = self.ticket.meta.get('apply_category') apply_category = self.ticket.meta.get('apply_category')
apply_type = self.ticket.meta.get('apply_type') apply_type = self.ticket.meta.get('apply_type')
approve_permission_name = self.ticket.meta.get('approve_permission_name', '') approve_permission_name = self.ticket.meta.get('approve_permission_name', '')
approved_applications_id = self.ticket.meta.get('approve_applications', []) approved_application_ids = self.ticket.meta.get('approve_applications', [])
approve_system_users_id = self.ticket.meta.get('approve_system_users', []) approve_system_user_ids = self.ticket.meta.get('approve_system_users', [])
approve_date_start = self.ticket.meta.get('approve_date_start') approve_date_start = self.ticket.meta.get('approve_date_start')
approve_date_expired = self.ticket.meta.get('approve_date_expired') approve_date_expired = self.ticket.meta.get('approve_date_expired')
permission_created_by = '{}:{}'.format( permission_created_by = '{}:{}'.format(
@ -121,7 +121,7 @@ class Handler(BaseHandler):
with tmp_to_org(self.ticket.org_id): with tmp_to_org(self.ticket.org_id):
application_permission = ApplicationPermission.objects.create(**permissions_data) application_permission = ApplicationPermission.objects.create(**permissions_data)
application_permission.users.add(self.ticket.applicant) application_permission.users.add(self.ticket.applicant)
application_permission.applications.set(approved_applications_id) application_permission.applications.set(approved_application_ids)
application_permission.system_users.set(approve_system_users_id) application_permission.system_users.set(approve_system_user_ids)
return application_permission return application_permission

View File

@ -27,11 +27,11 @@ class Handler(BaseHandler):
] ]
approve_actions = self.ticket.meta.get('approve_actions', Action.NONE) approve_actions = self.ticket.meta.get('approve_actions', Action.NONE)
approve_actions_display = Action.value_to_choices_display(approve_actions) approve_actions_display = Action.value_to_choices_display(approve_actions)
approve_assets_id = self.ticket.meta.get('approve_assets', []) approve_asset_ids = self.ticket.meta.get('approve_assets', [])
approve_system_users_id = self.ticket.meta.get('approve_system_users', []) approve_system_user_ids = self.ticket.meta.get('approve_system_users', [])
with tmp_to_org(self.ticket.org_id): with tmp_to_org(self.ticket.org_id):
assets = Asset.objects.filter(id__in=approve_assets_id) assets = Asset.objects.filter(id__in=approve_asset_ids)
system_users = SystemUser.objects.filter(id__in=approve_system_users_id) system_users = SystemUser.objects.filter(id__in=approve_system_user_ids)
approve_assets_display = [str(asset) for asset in assets] approve_assets_display = [str(asset) for asset in assets]
approve_system_users_display = [str(system_user) for system_user in system_users] approve_system_users_display = [str(system_user) for system_user in system_users]
meta_display_values = [ meta_display_values = [
@ -91,8 +91,8 @@ class Handler(BaseHandler):
return asset_permission return asset_permission
approve_permission_name = self.ticket.meta.get('approve_permission_name', ) approve_permission_name = self.ticket.meta.get('approve_permission_name', )
approve_assets_id = self.ticket.meta.get('approve_assets', []) approve_asset_ids = self.ticket.meta.get('approve_assets', [])
approve_system_users_id = self.ticket.meta.get('approve_system_users', []) approve_system_user_ids = self.ticket.meta.get('approve_system_users', [])
approve_actions = self.ticket.meta.get('approve_actions', Action.NONE) approve_actions = self.ticket.meta.get('approve_actions', Action.NONE)
approve_date_start = self.ticket.meta.get('approve_date_start') approve_date_start = self.ticket.meta.get('approve_date_start')
approve_date_expired = self.ticket.meta.get('approve_date_expired') approve_date_expired = self.ticket.meta.get('approve_date_expired')
@ -124,7 +124,7 @@ class Handler(BaseHandler):
with tmp_to_org(self.ticket.org_id): with tmp_to_org(self.ticket.org_id):
asset_permission = AssetPermission.objects.create(**permission_data) asset_permission = AssetPermission.objects.create(**permission_data)
asset_permission.users.add(self.ticket.applicant) asset_permission.users.add(self.ticket.applicant)
asset_permission.assets.set(approve_assets_id) asset_permission.assets.set(approve_asset_ids)
asset_permission.system_users.set(approve_system_users_id) asset_permission.system_users.set(approve_system_user_ids)
return asset_permission return asset_permission

View File

@ -98,10 +98,10 @@ class ApproveSerializer(serializers.Serializer):
apply_type = self.root.instance.meta.get('apply_type') apply_type = self.root.instance.meta.get('apply_type')
queries = Q(type=apply_type) queries = Q(type=apply_type)
queries &= Q(id__in=approve_applications) queries &= Q(id__in=approve_applications)
applications_id = Application.objects.filter(queries).values_list('id', flat=True) application_ids = Application.objects.filter(queries).values_list('id', flat=True)
applications_id = [str(application_id) for application_id in applications_id] application_ids = [str(application_id) for application_id in application_ids]
if applications_id: if application_ids:
return applications_id return application_ids
raise serializers.ValidationError(_( raise serializers.ValidationError(_(
'No `Application` are found under Organization `{}`'.format(self.root.instance.org_name) 'No `Application` are found under Organization `{}`'.format(self.root.instance.org_name)
@ -116,10 +116,10 @@ class ApproveSerializer(serializers.Serializer):
protocol = SystemUser.get_protocol_by_application_type(apply_type) protocol = SystemUser.get_protocol_by_application_type(apply_type)
queries = Q(protocol=protocol) queries = Q(protocol=protocol)
queries &= Q(id__in=approve_system_users) queries &= Q(id__in=approve_system_users)
system_users_id = SystemUser.objects.filter(queries).values_list('id', flat=True) system_user_ids = SystemUser.objects.filter(queries).values_list('id', flat=True)
system_users_id = [str(system_user_id) for system_user_id in system_users_id] system_user_ids = [str(system_user_id) for system_user_id in system_user_ids]
if system_users_id: if system_user_ids:
return system_users_id return system_user_ids
raise serializers.ValidationError(_( raise serializers.ValidationError(_(
'No `SystemUser` are found under Organization `{}`'.format(self.root.instance.org_name) 'No `SystemUser` are found under Organization `{}`'.format(self.root.instance.org_name)
@ -146,9 +146,9 @@ class ApplyApplicationSerializer(ApplySerializer, ApproveSerializer):
queries &= Q(type=apply_type) queries &= Q(type=apply_type)
with tmp_to_org(self.root.instance.org_id): with tmp_to_org(self.root.instance.org_id):
applications_id = Application.objects.filter(queries).values_list('id', flat=True)[:5] application_ids = Application.objects.filter(queries).values_list('id', flat=True)[:5]
applications_id = [str(application_id) for application_id in applications_id] application_ids = [str(application_id) for application_id in application_ids]
return applications_id return application_ids
def get_recommend_system_users(self, value): def get_recommend_system_users(self, value):
if not isinstance(self.root.instance, Ticket): if not isinstance(self.root.instance, Ticket):
@ -167,6 +167,6 @@ class ApplyApplicationSerializer(ApplySerializer, ApproveSerializer):
queries &= Q(protocol=protocol) queries &= Q(protocol=protocol)
with tmp_to_org(self.root.instance.org_id): with tmp_to_org(self.root.instance.org_id):
system_users_id = SystemUser.objects.filter(queries).values_list('id', flat=True)[:5] system_user_ids = SystemUser.objects.filter(queries).values_list('id', flat=True)[:5]
system_users_id = [str(system_user_id) for system_user_id in system_users_id] system_user_ids = [str(system_user_id) for system_user_id in system_user_ids]
return system_users_id return system_user_ids

View File

@ -99,10 +99,10 @@ class ApproveSerializer(serializers.Serializer):
return [] return []
with tmp_to_org(self.root.instance.org_id): with tmp_to_org(self.root.instance.org_id):
assets_id = Asset.objects.filter(id__in=approve_assets).values_list('id', flat=True) asset_ids = Asset.objects.filter(id__in=approve_assets).values_list('id', flat=True)
assets_id = [str(asset_id) for asset_id in assets_id] asset_ids = [str(asset_id) for asset_id in asset_ids]
if assets_id: if asset_ids:
return assets_id return asset_ids
raise serializers.ValidationError(_( raise serializers.ValidationError(_(
'No `Asset` are found under Organization `{}`'.format(self.root.instance.org_name) 'No `Asset` are found under Organization `{}`'.format(self.root.instance.org_name)
@ -115,10 +115,10 @@ class ApproveSerializer(serializers.Serializer):
with tmp_to_org(self.root.instance.org_id): with tmp_to_org(self.root.instance.org_id):
queries = Q(protocol__in=SystemUser.ASSET_CATEGORY_PROTOCOLS) queries = Q(protocol__in=SystemUser.ASSET_CATEGORY_PROTOCOLS)
queries &= Q(id__in=approve_system_users) queries &= Q(id__in=approve_system_users)
system_users_id = SystemUser.objects.filter(queries).values_list('id', flat=True) system_user_ids = SystemUser.objects.filter(queries).values_list('id', flat=True)
system_users_id = [str(system_user_id) for system_user_id in system_users_id] system_user_ids = [str(system_user_id) for system_user_id in system_user_ids]
if system_users_id: if system_user_ids:
return system_users_id return system_user_ids
raise serializers.ValidationError(_( raise serializers.ValidationError(_(
'No `SystemUser` are found under Organization `{}`'.format(self.root.instance.org_name) 'No `SystemUser` are found under Organization `{}`'.format(self.root.instance.org_name)
@ -144,9 +144,9 @@ class ApplyAssetSerializer(ApplySerializer, ApproveSerializer):
if not queries: if not queries:
return [] return []
with tmp_to_org(self.root.instance.org_id): with tmp_to_org(self.root.instance.org_id):
assets_id = Asset.objects.filter(queries).values_list('id', flat=True)[:5] asset_ids = Asset.objects.filter(queries).values_list('id', flat=True)[:5]
assets_id = [str(asset_id) for asset_id in assets_id] asset_ids = [str(asset_id) for asset_id in asset_ids]
return assets_id return asset_ids
def get_recommend_system_users(self, value): def get_recommend_system_users(self, value):
if not isinstance(self.root.instance, Ticket): if not isinstance(self.root.instance, Ticket):
@ -163,6 +163,6 @@ class ApplyAssetSerializer(ApplySerializer, ApproveSerializer):
queries &= Q(protocol__in=SystemUser.ASSET_CATEGORY_PROTOCOLS) queries &= Q(protocol__in=SystemUser.ASSET_CATEGORY_PROTOCOLS)
with tmp_to_org(self.root.instance.org_id): with tmp_to_org(self.root.instance.org_id):
system_users_id = SystemUser.objects.filter(queries).values_list('id', flat=True)[:5] system_user_ids = SystemUser.objects.filter(queries).values_list('id', flat=True)[:5]
system_users_id = [str(system_user_id) for system_user_id in system_users_id] system_user_ids = [str(system_user_id) for system_user_id in system_user_ids]
return system_users_id return system_user_ids

View File

@ -123,10 +123,10 @@ class UserViewSet(CommonApiMixin, UserQuerysetMixin, BulkModelViewSet):
def perform_bulk_update(self, serializer): def perform_bulk_update(self, serializer):
# TODO: 需要测试 # TODO: 需要测试
users_ids = [ user_ids = [
d.get("id") or d.get("pk") for d in serializer.validated_data d.get("id") or d.get("pk") for d in serializer.validated_data
] ]
users = current_org.get_members().filter(id__in=users_ids) users = current_org.get_members().filter(id__in=user_ids)
for user in users: for user in users:
self.check_object_permissions(self.request, user) self.check_object_permissions(self.request, user)
return super().perform_bulk_update(serializer) return super().perform_bulk_update(serializer)

View File

@ -346,11 +346,11 @@ class RoleMixin:
@classmethod @classmethod
def get_super_and_org_admins(cls, org=None): def get_super_and_org_admins(cls, org=None):
super_admins = cls.get_super_admins() super_admins = cls.get_super_admins()
super_admins_id = list(super_admins.values_list('id', flat=True)) super_admin_ids = list(super_admins.values_list('id', flat=True))
org_admins = cls.get_org_admins(org) org_admins = cls.get_org_admins(org)
org_admins_id = list(org_admins.values_list('id', flat=True)) org_admin_ids = list(org_admins.values_list('id', flat=True))
admins_id = set(org_admins_id + super_admins_id) admin_ids = set(org_admin_ids + super_admin_ids)
admins = User.objects.filter(id__in=admins_id) admins = User.objects.filter(id__in=admin_ids)
return admins return admins

View File

@ -57,16 +57,16 @@ class NodesGenerator(FakeDataGenerator):
class AssetsGenerator(FakeDataGenerator): class AssetsGenerator(FakeDataGenerator):
resource = 'asset' resource = 'asset'
admin_users_id: list admin_user_ids: list
nodes_id: list node_ids: list
def pre_generate(self): def pre_generate(self):
self.admin_users_id = list(AdminUser.objects.all().values_list('id', flat=True)) self.admin_user_ids = list(AdminUser.objects.all().values_list('id', flat=True))
self.nodes_id = list(Node.objects.all().values_list('id', flat=True)) self.node_ids = list(Node.objects.all().values_list('id', flat=True))
def set_assets_nodes(self, assets): def set_assets_nodes(self, assets):
for asset in assets: for asset in assets:
nodes_id_add_to = random.sample(self.nodes_id, 3) nodes_id_add_to = random.sample(self.node_ids, 3)
asset.nodes.add(*nodes_id_add_to) asset.nodes.add(*nodes_id_add_to)
def do_generate(self, batch, batch_size): def do_generate(self, batch, batch_size):
@ -79,7 +79,7 @@ class AssetsGenerator(FakeDataGenerator):
data = dict( data = dict(
ip=ip, ip=ip,
hostname=hostname, hostname=hostname,
admin_user_id=choice(self.admin_users_id), admin_user_id=choice(self.admin_user_ids),
created_by='Fake', created_by='Fake',
org_id=self.org.id org_id=self.org.id
) )

View File

@ -10,46 +10,46 @@ from perms.models import *
class AssetPermissionGenerator(FakeDataGenerator): class AssetPermissionGenerator(FakeDataGenerator):
resource = 'asset_permission' resource = 'asset_permission'
users_id: list user_ids: list
user_groups_id: list user_group_ids: list
assets_id: list asset_ids: list
nodes_id: list node_ids: list
system_users_id: list system_user_ids: list
def pre_generate(self): def pre_generate(self):
self.nodes_id = list(Node.objects.all().values_list('id', flat=True)) self.node_ids = list(Node.objects.all().values_list('id', flat=True))
self.assets_id = list(Asset.objects.all().values_list('id', flat=True)) self.asset_ids = list(Asset.objects.all().values_list('id', flat=True))
self.system_users_id = list(SystemUser.objects.all().values_list('id', flat=True)) self.system_user_ids = list(SystemUser.objects.all().values_list('id', flat=True))
self.users_id = list(User.objects.all().values_list('id', flat=True)) self.user_ids = list(User.objects.all().values_list('id', flat=True))
self.user_groups_id = list(UserGroup.objects.all().values_list('id', flat=True)) self.user_group_ids = list(UserGroup.objects.all().values_list('id', flat=True))
def set_users(self, perms): def set_users(self, perms):
through = AssetPermission.users.through through = AssetPermission.users.through
choices = self.users_id choices = self.user_ids
relation_name = 'user_id' relation_name = 'user_id'
self.set_relations(perms, through, relation_name, choices) self.set_relations(perms, through, relation_name, choices)
def set_user_groups(self, perms): def set_user_groups(self, perms):
through = AssetPermission.user_groups.through through = AssetPermission.user_groups.through
choices = self.user_groups_id choices = self.user_group_ids
relation_name = 'usergroup_id' relation_name = 'usergroup_id'
self.set_relations(perms, through, relation_name, choices) self.set_relations(perms, through, relation_name, choices)
def set_assets(self, perms): def set_assets(self, perms):
through = AssetPermission.assets.through through = AssetPermission.assets.through
choices = self.assets_id choices = self.asset_ids
relation_name = 'asset_id' relation_name = 'asset_id'
self.set_relations(perms, through, relation_name, choices) self.set_relations(perms, through, relation_name, choices)
def set_nodes(self, perms): def set_nodes(self, perms):
through = AssetPermission.nodes.through through = AssetPermission.nodes.through
choices = self.nodes_id choices = self.node_ids
relation_name = 'node_id' relation_name = 'node_id'
self.set_relations(perms, through, relation_name, choices) self.set_relations(perms, through, relation_name, choices)
def set_system_users(self, perms): def set_system_users(self, perms):
through = AssetPermission.system_users.through through = AssetPermission.system_users.through
choices = self.system_users_id choices = self.system_user_ids
relation_name = 'systemuser_id' relation_name = 'systemuser_id'
self.set_relations(perms, through, relation_name, choices) self.set_relations(perms, through, relation_name, choices)
@ -59,8 +59,8 @@ class AssetPermissionGenerator(FakeDataGenerator):
for perm in perms: for perm in perms:
if choice_count is None: if choice_count is None:
choice_count = choice(range(8)) choice_count = choice(range(8))
resources_id = sample(choices, choice_count) resource_ids = sample(choices, choice_count)
for rid in resources_id: for rid in resource_ids:
data = {'assetpermission_id': perm.id} data = {'assetpermission_id': perm.id}
data[relation_name] = rid data[relation_name] = rid
relations.append(through(**data)) relations.append(through(**data))

View File

@ -21,11 +21,11 @@ class UserGroupGenerator(FakeDataGenerator):
class UserGenerator(FakeDataGenerator): class UserGenerator(FakeDataGenerator):
resource = 'user' resource = 'user'
roles: list roles: list
groups_id: list group_ids: list
def pre_generate(self): def pre_generate(self):
self.roles = list(dict(User.ROLE.choices).keys()) self.roles = list(dict(User.ROLE.choices).keys())
self.groups_id = list(UserGroup.objects.all().values_list('id', flat=True)) self.group_ids = list(UserGroup.objects.all().values_list('id', flat=True))
def set_org(self, users): def set_org(self, users):
relations = [] relations = []
@ -39,7 +39,7 @@ class UserGenerator(FakeDataGenerator):
def set_groups(self, users): def set_groups(self, users):
relations = [] relations = []
for i in users: for i in users:
groups_to_join = sample(self.groups_id, 3) groups_to_join = sample(self.group_ids, 3)
_relations = [User.groups.through(user_id=i.id, usergroup_id=gid) for gid in groups_to_join] _relations = [User.groups.through(user_id=i.id, usergroup_id=gid) for gid in groups_to_join]
relations.extend(_relations) relations.extend(_relations)
User.groups.through.objects.bulk_create(relations, ignore_conflicts=True) User.groups.through.objects.bulk_create(relations, ignore_conflicts=True)