diff --git a/apps/assets/api/asset.py b/apps/assets/api/asset.py index 1227b01c9..ff5047ba3 100644 --- a/apps/assets/api/asset.py +++ b/apps/assets/api/asset.py @@ -11,8 +11,7 @@ from django.db.models import Q from common.mixins import IDInFilterMixin from common.utils import get_logger -from ..hands import IsSuperUser, IsValidUser, IsSuperUserOrAppUser, \ - NodePermissionUtil +from ..hands import IsSuperUser, IsValidUser, IsSuperUserOrAppUser from ..models import Asset, SystemUser, AdminUser, Node from .. import serializers from ..tasks import update_asset_hardware_info_manual, \ @@ -22,7 +21,7 @@ from ..utils import LabelFilter logger = get_logger(__file__) __all__ = [ - 'AssetViewSet', 'UserAssetListView', 'AssetListUpdateApi', + 'AssetViewSet', 'AssetListUpdateApi', 'AssetRefreshHardwareApi', 'AssetAdminUserTestApi' ] @@ -71,19 +70,6 @@ class AssetViewSet(IDInFilterMixin, LabelFilter, BulkModelViewSet): return queryset -class UserAssetListView(generics.ListAPIView): - queryset = Asset.objects.all() - serializer_class = serializers.AssetSerializer - permission_classes = (IsValidUser,) - - def get_queryset(self): - assets_granted = NodePermissionUtil.get_user_assets(self.request.user).keys() - queryset = self.queryset.filter( - id__in=[asset.id for asset in assets_granted] - ) - return queryset - - class AssetListUpdateApi(IDInFilterMixin, ListBulkCreateUpdateDestroyAPIView): """ Asset bulk update api diff --git a/apps/assets/api/node.py b/apps/assets/api/node.py index a6d00b44a..e5ace021e 100644 --- a/apps/assets/api/node.py +++ b/apps/assets/api/node.py @@ -31,7 +31,7 @@ from .. import serializers logger = get_logger(__file__) __all__ = [ 'NodeViewSet', 'NodeChildrenApi', - 'NodeAssetsApi', 'NodeWithAssetsApi', + 'NodeAssetsApi', 'NodeAddAssetsApi', 'NodeRemoveAssetsApi', 'NodeReplaceAssetsApi', 'NodeAddChildrenApi', 'RefreshNodeHardwareInfoApi', @@ -42,14 +42,7 @@ __all__ = [ class NodeViewSet(BulkModelViewSet): queryset = Node.objects.all() permission_classes = (IsSuperUser,) - # serializer_class = serializers.NodeSerializer - - def get_serializer_class(self): - show_current_asset = self.request.query_params.get('show_current_asset') - if show_current_asset: - return serializers.NodeCurrentSerializer - else: - return serializers.NodeSerializer + serializer_class = serializers.NodeSerializer def perform_create(self, serializer): child_key = Node.root().get_next_child_key() @@ -57,32 +50,32 @@ class NodeViewSet(BulkModelViewSet): serializer.save() -class NodeWithAssetsApi(generics.ListAPIView): - permission_classes = (IsSuperUser,) - serializers = serializers.NodeSerializer - - def get_node(self): - pk = self.kwargs.get('pk') or self.request.query_params.get('node') - if not pk: - node = Node.root() - else: - node = get_object_or_404(Node, pk) - return node - - def get_queryset(self): - queryset = [] - node = self.get_node() - children = node.get_children() - assets = node.get_assets() - queryset.extend(list(children)) - - for asset in assets: - node = Node() - node.id = asset.id - node.parent = node.id - node.value = asset.hostname - queryset.append(node) - return queryset +# class NodeWithAssetsApi(generics.ListAPIView): +# permission_classes = (IsSuperUser,) +# serializers = serializers.NodeSerializer +# +# def get_node(self): +# pk = self.kwargs.get('pk') or self.request.query_params.get('node') +# if not pk: +# node = Node.root() +# else: +# node = get_object_or_404(Node, pk) +# return node +# +# def get_queryset(self): +# queryset = [] +# node = self.get_node() +# children = node.get_children() +# assets = node.get_assets() +# queryset.extend(list(children)) +# +# for asset in assets: +# node = Node() +# node.id = asset.id +# node.parent = node.id +# node.value = asset.hostname +# queryset.append(node) +# return queryset class NodeChildrenApi(mixins.ListModelMixin, generics.CreateAPIView): @@ -147,9 +140,9 @@ class NodeChildrenApi(mixins.ListModelMixin, generics.CreateAPIView): for asset in assets: node_fake = Node() node_fake.id = asset.id - node_fake.parent = node - node_fake.value = asset.hostname node_fake.is_node = False + node_fake.parent_id = node.id + node_fake.value = asset.hostname queryset.append(node_fake) queryset = sorted(queryset, key=lambda x: x.is_node, reverse=True) return queryset @@ -185,7 +178,7 @@ class NodeAddChildrenApi(generics.UpdateAPIView): for node in children: if not node: continue - node.set_parent(instance) + node.parent = instance return Response("OK") diff --git a/apps/assets/hands.py b/apps/assets/hands.py index ad44052d3..a1a376135 100644 --- a/apps/assets/hands.py +++ b/apps/assets/hands.py @@ -14,4 +14,3 @@ from common.mixins import AdminUserRequiredMixin from common.permissions import IsAppUser, IsSuperUser, IsValidUser, IsSuperUserOrAppUser from users.models import User, UserGroup -from perms.utils import NodePermissionUtil diff --git a/apps/assets/models/asset.py b/apps/assets/models/asset.py index 6e6b6c678..a974d3385 100644 --- a/apps/assets/models/asset.py +++ b/apps/assets/models/asset.py @@ -5,6 +5,7 @@ import uuid import logging import random +from functools import reduce from django.db import models from django.utils.translation import ugettext_lazy as _ @@ -149,22 +150,15 @@ class Asset(models.Model): nodes = self.nodes.all() or [Node.root()] return nodes - @property - def nodes_cache_key(self): - key = "NODES_OF_{}".format(str(self.id)) - return key - - def get_nodes_or_cache(self): - cached = cache.get(self.nodes_cache_key) - if cached is not None: - return cached - nodes = list(self.get_nodes()) - cache.set(self.nodes_cache_key, nodes, 3600) + def get_all_nodes(self, flat=False): + nodes = [] + for node in self.get_nodes(): + _nodes = node.get_ancestor(with_self=True) + _nodes.append(_nodes) + if flat: + nodes = list(reduce(lambda x, y: set(x) | set(y), nodes)) return nodes - def expire_nodes_cache(self): - cache.delete(self.nodes_cache_key) - @property def hardware_info(self): if self.cpu_count: diff --git a/apps/assets/models/node.py b/apps/assets/models/node.py index ed712c8f8..4f4f9ad8b 100644 --- a/apps/assets/models/node.py +++ b/apps/assets/models/node.py @@ -5,7 +5,7 @@ import uuid from django.db import models, transaction from django.db.models import Q from django.utils.translation import ugettext_lazy as _ - +from common.utils import with_cache __all__ = ['Node'] @@ -22,32 +22,36 @@ class Node(models.Model): def __str__(self): return self.full_value + def __eq__(self, other): + return self.key == other.key + + def __gt__(self, other): + if self.is_root(): + return True + self_key = [int(k) for k in self.key.split(':')] + other_key = [int(k) for k in other.key.split(':')] + if len(self_key) < len(other_key): + return True + elif len(self_key) > len(other_key): + return False + else: + return self_key[-1] < other_key[-1] + @property def name(self): return self.value @property def full_value(self): - ancestor = [a.value for a in self.ancestor] + ancestor = [a.value for a in self.get_ancestor(with_self=True)] if self.is_root(): return self.value - ancestor.append(self.value) return ' / '.join(ancestor) @property def level(self): return len(self.key.split(':')) - def set_parent(self, instance): - children = self.get_all_children() - old_key = self.key - with transaction.atomic(): - self.parent = instance - for child in children: - child.key = child.key.replace(old_key, self.key, 1) - child.save() - self.save() - def get_next_child_key(self): mark = self.child_mark self.child_mark += 1 @@ -55,32 +59,35 @@ class Node(models.Model): return "{}:{}".format(self.key, mark) def create_child(self, value): - child_key = self.get_next_child_key() - child = self.__class__.objects.create(key=child_key, value=value) - return child + with transaction.atomic(): + child_key = self.get_next_child_key() + child = self.__class__.objects.create(key=child_key, value=value) + return child - def get_children(self): + def get_children(self, with_self=False): + pattern = r'^{0}$|^{}:[0-9]+$' if with_self else r'^{}:[0-9]+$' return self.__class__.objects.filter( - key__regex=r'^{}:[0-9]+$'.format(self.key) + key__regex=pattern.format(self.key) ) - def get_children_with_self(self): + def get_all_children(self, with_self=False): + pattern = r'^{0}$|^{0}:' if with_self else r'^{0}' return self.__class__.objects.filter( - key__regex=r'^{0}$|^{0}:[0-9]+$'.format(self.key) + key__regex=pattern.format(self.key) ) - def get_all_children(self): - return self.__class__.objects.filter( - key__startswith='{}:'.format(self.key) - ) - - def get_all_children_with_self(self): - return self.__class__.objects.filter( - key__regex=r'^{0}$|^{0}:'.format(self.key) + def get_sibling(self, with_self=False): + key = ':'.join(self.key.split(':')[:-1]) + pattern = r'^{}:[0-9]+$'.format(key) + sibling = self.__class__.objects.filter( + key__regex=pattern.format(self.key) ) + if not with_self: + sibling = sibling.exclude(key=self.key) + return sibling def get_family(self): - ancestor = self.ancestor + ancestor = self.get_ancestor() children = self.get_all_children() return [*tuple(ancestor), self, *tuple(children)] @@ -91,7 +98,7 @@ class Node(models.Model): Q(nodes__id=self.id) | Q(nodes__isnull=True) ) else: - assets = Asset.objects.filter(nodes__id=self.id) + assets = self.assets.all() return assets def get_valid_assets(self): @@ -102,8 +109,8 @@ class Node(models.Model): if self.is_root(): assets = Asset.objects.all() else: - nodes = self.get_all_children_with_self() - assets = Asset.objects.filter(nodes__in=nodes).distinct() + pattern = r'^{0}$|^{0}:'.format(self.key) + assets = Asset.objects.filter(nodes__key__regex=pattern) return assets def get_all_valid_assets(self): @@ -125,26 +132,33 @@ class Node(models.Model): @parent.setter def parent(self, parent): - self.key = parent.get_next_child_key() + if self.is_node: + children = self.get_all_children() + old_key = self.key + with transaction.atomic(): + self.key = parent.get_next_child_key() + for child in children: + child.key = child.key.replace(old_key, self.key, 1) + child.save() + self.save() + else: + self.key = parent.key+':fake' - @property - def ancestor(self): + def get_ancestor(self, with_self=False): if self.is_root(): ancestor = self.__class__.objects.filter(key='0') - else: - _key = self.key.split(':') - ancestor_keys = [] - for i in range(len(_key)-1): - _key.pop() - ancestor_keys.append(':'.join(_key)) - ancestor = self.__class__.objects.filter(key__in=ancestor_keys) - ancestor = list(ancestor) - return ancestor + return ancestor - @property - def ancestor_with_self(self): - ancestor = list(self.ancestor) - ancestor.insert(0, self) + _key = self.key.split(':') + if not with_self: + _key.pop() + ancestor_keys = [] + for i in range(len(_key)): + ancestor_keys.append(':'.join(_key)) + _key.pop() + ancestor = self.__class__.objects.filter( + key__in=ancestor_keys + ).order_by('key') return ancestor @classmethod @@ -152,4 +166,6 @@ class Node(models.Model): obj, created = cls.objects.get_or_create( key='0', defaults={"key": '0', 'value': "ROOT"} ) + print(obj) return obj + diff --git a/apps/assets/serializers/asset.py b/apps/assets/serializers/asset.py index ac666e3a7..a0fdfab73 100644 --- a/apps/assets/serializers/asset.py +++ b/apps/assets/serializers/asset.py @@ -16,8 +16,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): """ 资产的数据结构 """ - nodes = serializers.SerializerMethodField() - class Meta: model = Asset list_serializer_class = BulkListSerializer @@ -31,10 +29,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): ]) return fields - @staticmethod - def get_nodes(obj): - return [n.id for n in obj.get_nodes_or_cache()] - class AssetGrantedSerializer(serializers.ModelSerializer): """ diff --git a/apps/assets/serializers/node.py b/apps/assets/serializers/node.py index 73639a100..56e01f742 100644 --- a/apps/assets/serializers/node.py +++ b/apps/assets/serializers/node.py @@ -9,7 +9,7 @@ from .asset import AssetGrantedSerializer __all__ = [ 'NodeSerializer', "NodeGrantedSerializer", "NodeAddChildrenSerializer", - "NodeAssetsSerializer", "NodeCurrentSerializer", + "NodeAssetsSerializer", ] @@ -64,11 +64,11 @@ class NodeSerializer(serializers.ModelSerializer): @staticmethod def get_parent(obj): - return obj.parent.id + return obj.parent.id if obj.is_node else obj.parent_id @staticmethod def get_assets_amount(obj): - return obj.get_all_assets().count() + return obj.get_all_assets().count() if obj.is_node else 0 def get_fields(self): fields = super().get_fields() @@ -77,12 +77,6 @@ class NodeSerializer(serializers.ModelSerializer): return fields -class NodeCurrentSerializer(NodeSerializer): - @staticmethod - def get_assets_amount(obj): - return obj.get_assets().count() - - class NodeAssetsSerializer(serializers.ModelSerializer): assets = serializers.PrimaryKeyRelatedField(many=True, queryset=Asset.objects.all()) diff --git a/apps/assets/signals_handler.py b/apps/assets/signals_handler.py index 16459c786..157c88012 100644 --- a/apps/assets/signals_handler.py +++ b/apps/assets/signals_handler.py @@ -64,7 +64,6 @@ def on_system_user_assets_change(sender, instance=None, **kwargs): @receiver(m2m_changed, sender=Asset.nodes.through) def on_asset_node_changed(sender, instance=None, **kwargs): if isinstance(instance, Asset): - instance.expire_nodes_cache() if kwargs['action'] == 'post_add': logger.debug("Asset node change signal received") nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) @@ -81,10 +80,6 @@ def on_asset_node_changed(sender, instance=None, **kwargs): def on_node_assets_changed(sender, instance=None, **kwargs): if isinstance(instance, Node): assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) - # 清理资产节点缓存 - for asset in assets: - asset.expire_nodes_cache() - if kwargs['action'] == 'post_add': logger.debug("Node assets change signal received") # 重新关联系统用户和资产的关系 diff --git a/apps/assets/templates/assets/_asset_list_modal.html b/apps/assets/templates/assets/_asset_list_modal.html index a0d96a7ef..faf569137 100644 --- a/apps/assets/templates/assets/_asset_list_modal.html +++ b/apps/assets/templates/assets/_asset_list_modal.html @@ -95,7 +95,7 @@ function initTree2() { }; var zNodes = []; - $.get("{% url 'api-assets:node-list' %}?show_current_asset=1", function(data, status){ + $.get("{% url 'api-assets:node-list' %}", function(data, status){ $.each(data, function (index, value) { value["pId"] = value["parent"]; {#value["open"] = true;#} diff --git a/apps/assets/templates/assets/asset_list.html b/apps/assets/templates/assets/asset_list.html index ab08f7b67..b5e53aaba 100644 --- a/apps/assets/templates/assets/asset_list.html +++ b/apps/assets/templates/assets/asset_list.html @@ -399,8 +399,7 @@ function initTree() { }; var zNodes = []; - var query_params = {'show_current_asset': getCookie('show_current_asset')}; - $.get("{% url 'api-assets:node-list' %}", query_params, function(data, status){ + $.get("{% url 'api-assets:node-list' %}", function(data, status){ $.each(data, function (index, value) { value["pId"] = value["parent"]; if (value["key"] === "0") { @@ -436,7 +435,7 @@ $(document).ready(function(){ initTable(); initTree(); - if(getCookie('show_current_asset') === 'yes'){ + if(getCookie('show_current_asset') === '1'){ $('#show_all_asset').css('display', 'inline-block'); } else{ @@ -564,7 +563,7 @@ $(document).ready(function(){ hideRMenu(); $(this).css('display', 'none'); $('#show_all_asset').css('display', 'inline-block'); - setCookie('show_current_asset', 'yes'); + setCookie('show_current_asset', '1'); location.reload(); }) .on('click', '.btn-show-all-asset', function(){ diff --git a/apps/assets/urls/api_urls.py b/apps/assets/urls/api_urls.py index 4429d0f24..ce622d648 100644 --- a/apps/assets/urls/api_urls.py +++ b/apps/assets/urls/api_urls.py @@ -23,8 +23,6 @@ urlpatterns = [ api.AssetRefreshHardwareApi.as_view(), name='asset-refresh'), url(r'^v1/assets/(?P[0-9a-zA-Z\-]{36})/alive/$', api.AssetAdminUserTestApi.as_view(), name='asset-alive-test'), - url(r'^v1/assets/user-assets/$', - api.UserAssetListView.as_view(), name='user-asset-list'), url(r'^v1/admin-user/(?P[0-9a-zA-Z\-]{36})/nodes/$', api.ReplaceNodesAdminUserApi.as_view(), name='replace-nodes-admin-user'), url(r'^v1/admin-user/(?P[0-9a-zA-Z\-]{36})/auth/$', @@ -35,17 +33,26 @@ urlpatterns = [ api.SystemUserPushApi.as_view(), name='system-user-push'), url(r'^v1/system-user/(?P[0-9a-zA-Z\-]{36})/connective/$', api.SystemUserTestConnectiveApi.as_view(), name='system-user-connective'), - url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/children/$', api.NodeChildrenApi.as_view(), name='node-children'), + url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/children/$', + api.NodeChildrenApi.as_view(), name='node-children'), url(r'^v1/nodes/children/$', api.NodeChildrenApi.as_view(), name='node-children-2'), - url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/children/add/$', api.NodeAddChildrenApi.as_view(), name='node-add-children'), - url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/assets/$', api.NodeAssetsApi.as_view(), name='node-assets'), - url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/assets/add/$', api.NodeAddAssetsApi.as_view(), name='node-add-assets'), - url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/assets/replace/$', api.NodeReplaceAssetsApi.as_view(), name='node-replace-assets'), - url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/assets/remove/$', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'), - url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/refresh-hardware-info/$', api.RefreshNodeHardwareInfoApi.as_view(), name='node-refresh-hardware-info'), - url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/test-connective/$', api.TestNodeConnectiveApi.as_view(), name='node-test-connective'), + url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/children/add/$', + api.NodeAddChildrenApi.as_view(), name='node-add-children'), + url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/assets/$', + api.NodeAssetsApi.as_view(), name='node-assets'), + url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/assets/add/$', + api.NodeAddAssetsApi.as_view(), name='node-add-assets'), + url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/assets/replace/$', + api.NodeReplaceAssetsApi.as_view(), name='node-replace-assets'), + url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/assets/remove/$', + api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'), + url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/refresh-hardware-info/$', + api.RefreshNodeHardwareInfoApi.as_view(), name='node-refresh-hardware-info'), + url(r'^v1/nodes/(?P[0-9a-zA-Z\-]{36})/test-connective/$', + api.TestNodeConnectiveApi.as_view(), name='node-test-connective'), - url(r'^v1/gateway/(?P[0-9a-zA-Z\-]{36})/test-connective/$', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'), + url(r'^v1/gateway/(?P[0-9a-zA-Z\-]{36})/test-connective/$', + api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'), ] urlpatterns += router.urls diff --git a/apps/common/api.py b/apps/common/api.py index 209d09747..63ed9723d 100644 --- a/apps/common/api.py +++ b/apps/common/api.py @@ -21,23 +21,13 @@ class MailTestingAPI(APIView): serializer = self.serializer_class(data=request.data) if serializer.is_valid(): email_host_user = serializer.validated_data["EMAIL_HOST_USER"] - kwargs = { - "host": serializer.validated_data["EMAIL_HOST"], - "port": serializer.validated_data["EMAIL_PORT"], - "username": serializer.validated_data["EMAIL_HOST_USER"], - "password": serializer.validated_data["EMAIL_HOST_PASSWORD"], - "use_ssl": serializer.validated_data["EMAIL_USE_SSL"], - "use_tls": serializer.validated_data["EMAIL_USE_TLS"] - } - connection = get_connection(timeout=5, **kwargs) + for k, v in serializer.validated_data.items(): + if k.startswith('EMAIL'): + setattr(settings, k, v) try: - connection.open() - except Exception as e: - return Response({"error": str(e)}, status=401) - - try: - send_mail("Test", "Test smtp setting", email_host_user, - [email_host_user], connection=connection) + subject = "Test" + message = "Test smtp setting" + send_mail(subject, message, email_host_user, [email_host_user]) except Exception as e: return Response({"error": str(e)}, status=401) diff --git a/apps/common/tasks.py b/apps/common/tasks.py index dec738921..bfb005511 100644 --- a/apps/common/tasks.py +++ b/apps/common/tasks.py @@ -2,6 +2,7 @@ from django.core.mail import send_mail from django.conf import settings from celery import shared_task from .utils import get_logger +from .models import Setting logger = get_logger(__file__) @@ -21,6 +22,10 @@ def send_mail_async(*args, **kwargs): Example: send_mail_sync.delay(subject, message, recipient_list, fail_silently=False, html_message=None) """ + configs = Setting.objects.filter(name__startswith='EMAIL') + for config in configs: + setattr(settings, config.name, config.cleaned_value) + if len(args) == 3: args = list(args) args[0] = settings.EMAIL_SUBJECT_PREFIX + args[0] diff --git a/apps/common/utils.py b/apps/common/utils.py index d73c094ec..deaeb5280 100644 --- a/apps/common/utils.py +++ b/apps/common/utils.py @@ -16,6 +16,7 @@ import calendar import threading from io import StringIO import uuid +from functools import wraps import paramiko import sshpubkeys @@ -395,3 +396,17 @@ class TeeObj: def close(self): self.file_obj.close() + +def with_cache(func): + cache = {} + key = "_{}.{}".format(func.__module__, func.__name__) + + @wraps(func) + def wrapper(*args, **kwargs): + cached = cache.get(key) + if cached: + return cached + res = func(*args, **kwargs) + cache[key] = res + return res + return wrapper diff --git a/apps/perms/api.py b/apps/perms/api.py index bd2fb1139..33a027064 100644 --- a/apps/perms/api.py +++ b/apps/perms/api.py @@ -41,11 +41,11 @@ class AssetPermissionViewSet(viewsets.ModelViewSet): asset = get_object_or_404(Asset, pk=asset_id) permissions = set(queryset.filter(assets=asset)) for node in asset.nodes.all(): - inherit_nodes.update(set(node.ancestor_with_self)) + inherit_nodes.update(set(node.get_ancestor(with_self=True))) elif node_id: node = get_object_or_404(Node, pk=node_id) permissions = set(queryset.filter(nodes=node)) - inherit_nodes = node.ancestor + inherit_nodes = node.get_ancestor() for n in inherit_nodes: _permissions = queryset.filter(nodes=n) @@ -70,7 +70,8 @@ class UserGrantedAssetsApi(ListAPIView): else: user = self.request.user - for k, v in AssetPermissionUtil.get_user_assets(user).items(): + util = AssetPermissionUtil(user) + for k, v in util.get_assets().items(): if k.is_unixlike(): system_users_granted = [s for s in v if s.protocol == 'ssh'] else: @@ -95,7 +96,8 @@ class UserGrantedNodesApi(ListAPIView): user = get_object_or_404(User, id=user_id) else: user = self.request.user - nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) + util = AssetPermissionUtil(user) + nodes = util.get_nodes_with_assets() return nodes.keys() def get_permissions(self): @@ -116,7 +118,8 @@ class UserGrantedNodesWithAssetsApi(ListAPIView): else: user = get_object_or_404(User, id=user_id) - nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) + util = AssetPermissionUtil(user) + nodes = util.get_nodes_with_assets() for node, _assets in nodes.items(): assets = _assets.keys() for k, v in _assets.items(): @@ -147,8 +150,9 @@ class UserGrantedNodeAssetsApi(ListAPIView): user = get_object_or_404(User, id=user_id) else: user = self.request.user + util = AssetPermissionUtil(user) node = get_object_or_404(Node, id=node_id) - nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) + nodes = util.get_nodes_with_assets() assets = nodes.get(node, []) for asset, system_users in assets.items(): asset.system_users_granted = system_users @@ -172,7 +176,8 @@ class UserGroupGrantedAssetsApi(ListAPIView): return queryset user_group = get_object_or_404(UserGroup, id=user_group_id) - assets = AssetPermissionUtil.get_user_group_assets(user_group) + util = AssetPermissionUtil(user_group) + assets = util.get_assets() for k, v in assets.items(): k.system_users_granted = v queryset.append(k) @@ -189,7 +194,8 @@ class UserGroupGrantedNodesApi(ListAPIView): if group_id: group = get_object_or_404(UserGroup, id=group_id) - nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(group) + util = AssetPermissionUtil(group) + nodes = util.get_nodes_with_assets() return nodes.keys() return queryset @@ -206,7 +212,8 @@ class UserGroupGrantedNodesWithAssetsApi(ListAPIView): return queryset user_group = get_object_or_404(UserGroup, id=user_group_id) - nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(user_group) + util = AssetPermissionUtil(user_group) + nodes = util.get_nodes_with_assets() for node, _assets in nodes.items(): assets = _assets.keys() for asset, system_users in _assets.items(): @@ -226,7 +233,8 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView): user_group = get_object_or_404(UserGroup, id=user_group_id) node = get_object_or_404(Node, id=node_id) - nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(user_group) + util = AssetPermissionUtil(user_group) + nodes = util.get_nodes_with_assets() assets = nodes.get(node, []) for asset, system_users in assets.items(): asset.system_users_granted = system_users @@ -246,7 +254,8 @@ class ValidateUserAssetPermissionView(APIView): asset = get_object_or_404(Asset, id=asset_id) system_user = get_object_or_404(SystemUser, id=system_id) - assets_granted = AssetPermissionUtil.get_user_assets(user) + util = AssetPermissionUtil(user) + assets_granted = util.get_assets() if system_user in assets_granted.get(asset, []): return Response({'msg': True}, status=200) else: diff --git a/apps/perms/utils.py b/apps/perms/utils.py index 4844cda06..475acfb68 100644 --- a/apps/perms/utils.py +++ b/apps/perms/utils.py @@ -1,359 +1,123 @@ # coding: utf-8 from __future__ import absolute_import, unicode_literals -import collections from collections import defaultdict -from django.utils import timezone -import copy +from django.db.models import Q -from common.utils import set_or_append_attr_bulk, get_logger +from common.utils import get_logger from .models import AssetPermission from .hands import Node logger = get_logger(__file__) -class Tree: - def __init__(self): - self.__all_nodes = list(Node.objects.all().prefetch_related('assets')) - self.__node_asset_map = defaultdict(set) - self.nodes = defaultdict(dict) - self.root = Node.root() - self.init_node_asset_map() +def get_user_permissions(user, include_group=True): + if include_group: + groups = user.groups.all() + arg = Q(users=user) | Q(user_groups=groups) + else: + arg = Q(users=user) + return AssetPermission.objects.all().valid().filter(arg) - def init_node_asset_map(self): - for node in self.__all_nodes: - assets = node.get_assets().values_list('id', flat=True) - for asset in assets: - self.__node_asset_map[str(asset)].add(node) - def add_asset(self, asset, system_users): - nodes = self.__node_asset_map.get(str(asset.id), []) - self.add_nodes(nodes) - for node in nodes: - self.nodes[node][asset].update(system_users) +def get_user_group_permissions(user_group): + return AssetPermission.objects.all().valid().filter( + user_groups=user_group + ) - def add_node(self, node): - if node in self.nodes: - return - else: - self.nodes[node] = defaultdict(set) - if node.key == self.root.key: - return - parent_key = ':'.join(node.key.split(':')[:-1]) - for n in self.__all_nodes: - if n.key == parent_key: - self.add_node(n) - break - def add_nodes(self, nodes): - for node in nodes: - self.add_node(node) +def get_asset_permissions(asset, include_node=True): + if include_node: + nodes = asset.get_all_nodes(flat=True) + arg = Q(assets=asset) | Q(nodes=nodes) + else: + arg = Q(assets=asset) + return AssetPermission.objects.all().valid().filter(arg) + + +def get_node_permissions(node): + return AssetPermission.objects.all().valid().filter(nodes=node) + + +def get_system_user_permissions(system_user): + return AssetPermission.objects.valid().all().filter( + system_users=system_user + ) class AssetPermissionUtil: - @staticmethod - def get_user_permissions(user): - return AssetPermission.objects.all().valid().filter(users=user) + get_permissions_map = { + "User": get_user_permissions, + "UserGroup": get_user_group_permissions, + "Asset": get_asset_permissions, + "Node": get_node_permissions, + "SystemUser": get_node_permissions, + } - @staticmethod - def get_user_group_permissions(user_group): - return AssetPermission.objects.all().valid().filter( - user_groups=user_group - ) + def __init__(self, obj): + self.object = obj + self._permissions = None - @staticmethod - def get_asset_permissions(asset): - return AssetPermission.objects.all().valid().filter( - assets=asset - ) + @property + def permissions(self): + if self._permissions: + return self._permissions + object_cls = self.object.__class__.__name__ + func = self.get_permissions_map[object_cls] + permissions = func(self.object) + self._permissions = permissions + return permissions - @staticmethod - def get_node_permissions(node): - return AssetPermission.objects.all().valid().filter(nodes=node) - - @staticmethod - def get_system_user_permissions(system_user): - return AssetPermission.objects.valid().all().filter( - system_users=system_user - ) - - @classmethod - def get_user_group_nodes(cls, group): + def get_nodes_direct(self): + """ + 返回用户/组授权规则直接关联的节点 + :return: {node1: set(system_user1,)} + """ nodes = defaultdict(set) - permissions = cls.get_user_group_permissions(group) + permissions = self.permissions.prefetch_related('nodes', 'system_users') for perm in permissions: - _nodes = perm.nodes.all() - _system_users = perm.system_users.all() - set_or_append_attr_bulk(_nodes, 'permission', perm.id) - for node in _nodes: - nodes[node].update(set(_system_users)) + for node in perm.nodes.all(): + nodes[node].update(perm.system_users.all()) return nodes - @classmethod - def get_user_group_assets_direct(cls, group): + def get_assets_direct(self): + """ + 返回用户授权规则直接关联的资产 + :return: {asset1: set(system_user1,)} + """ assets = defaultdict(set) - permissions = cls.get_user_group_permissions(group) + permissions = self.permissions.prefetch_related('assets', 'system_users') for perm in permissions: - _assets = perm.assets.all().valid() - _system_users = perm.system_users.all() - set_or_append_attr_bulk(_assets, 'permission', perm.id) - for asset in _assets: - assets[asset].update(set(_system_users)) + for asset in perm.assets.all().valid().prefetch_related('nodes'): + assets[asset].update(perm.system_users.all()) return assets - @classmethod - def get_user_group_nodes_assets(cls, group): - assets = defaultdict(set) - nodes = cls.get_user_group_nodes(group) - for node, _system_users in nodes.items(): - _assets = node.get_all_valid_assets() - set_or_append_attr_bulk(_assets, 'inherit_node', node.id) - set_or_append_attr_bulk(_assets, 'permission', getattr(node, 'permission', None)) - for asset in _assets: - assets[asset].update(set(_system_users)) - return assets - - @classmethod - def get_user_group_assets(cls, group): - assets = defaultdict(set) - _assets = cls.get_user_group_assets_direct(group) - _nodes_assets = cls.get_user_group_nodes_assets(group) - for asset, _system_users in _assets.items(): - assets[asset].update(set(_system_users)) - for asset, _system_users in _nodes_assets.items(): - assets[asset].update(set(_system_users)) - return assets - - @classmethod - def get_user_group_nodes_with_assets(cls, group): - """ - :param group: - :return: {node: {asset: set(su1, su2)}} - """ - _assets = cls.get_user_group_assets(group) - tree = Tree() - for asset, _system_users in _assets.items(): - _nodes = asset.get_nodes_or_cache() - tree.add_nodes(_nodes) - for node in _nodes: - tree.nodes[node][asset].update(_system_users) - return tree.nodes - - @classmethod - def get_user_assets_direct(cls, user): - assets = defaultdict(set) - permissions = list(cls.get_user_permissions(user)) - for perm in permissions: - _assets = perm.assets.all().valid() - _system_users = perm.system_users.all() - set_or_append_attr_bulk(_assets, 'permission', perm.id) - for asset in _assets: - assets[asset].update(set(_system_users)) - return assets - - @classmethod - def get_user_nodes_direct(cls, user): - nodes = defaultdict(set) - permissions = cls.get_user_permissions(user) - for perm in permissions: - _nodes = perm.nodes.all() - _system_users = perm.system_users.all() - set_or_append_attr_bulk(_nodes, 'permission', perm.id) - for node in _nodes: - nodes[node].update(set(_system_users)) - return nodes - - @classmethod - def get_user_nodes_inherit_group(cls, user): - nodes = defaultdict(set) - groups = user.groups.all() - for group in groups: - _nodes = cls.get_user_group_nodes(group) - for node, system_users in _nodes.items(): - nodes[node].update(set(system_users)) - return nodes - - @classmethod - def get_user_nodes(cls, user): - nodes = cls.get_user_nodes_direct(user) - nodes_inherit = cls.get_user_nodes_inherit_group(user) - for node, system_users in nodes_inherit.items(): - nodes[node].update(set(system_users)) - return nodes - - @classmethod - def get_user_nodes_assets_direct(cls, user): - assets = defaultdict(set) - nodes = cls.get_user_nodes_direct(user) - for node, _system_users in nodes.items(): - _assets = node.get_all_valid_assets() - set_or_append_attr_bulk(_assets, 'inherit_node', node.id) - set_or_append_attr_bulk(_assets, 'permission', getattr(node, 'permission', None)) - for asset in _assets: - assets[asset].update(set(_system_users)) - return assets - - @classmethod - def get_user_assets_inherit_group(cls, user): - assets = defaultdict(set) - for group in user.groups.all(): - _assets = cls.get_user_group_assets(group) - set_or_append_attr_bulk(_assets, 'inherit_group', group.id) - for asset, _system_users in _assets.items(): - assets[asset].update(_system_users) - return assets - - @classmethod - def get_user_assets(cls, user): - assets = defaultdict(set) - _assets_direct = cls.get_user_assets_direct(user) - _nodes_assets_direct = cls.get_user_nodes_assets_direct(user) - _assets_inherit_group = cls.get_user_assets_inherit_group(user) - for asset, _system_users in _assets_direct.items(): - assets[asset].update(_system_users) - for asset, _system_users in _nodes_assets_direct.items(): - assets[asset].update(_system_users) - for asset, _system_users in _assets_inherit_group.items(): - assets[asset].update(_system_users) - return assets - - @classmethod - def get_user_nodes_with_assets(cls, user): - """ - :param user: - :return: {node: {asset: set(su1, su2)}} - """ - tree = Tree() - _assets = cls.get_user_assets(user) - for asset, _system_users in _assets.items(): - tree.add_asset(asset, _system_users) - # _nodes = asset.get_nodes() - # tree.add_nodes(_nodes) - # for node in _nodes: - # tree.nodes[node][asset].update(_system_users) - return tree.nodes - - @classmethod - def get_system_user_assets(cls, system_user): - assets = set() - permissions = cls.get_system_user_permissions(system_user) - for perm in permissions: - assets.update(set(perm.assets.all().valid())) - nodes = perm.nodes.all() - for node in nodes: - assets.update(set(node.get_all_valid_assets())) - return assets - - @classmethod - def get_node_system_users(cls, node): - system_users = set() - permissions = cls.get_node_permissions(node) - for perm in permissions: - system_users.update(perm.system_users.all()) - return system_users - - -# Abandon -class NodePermissionUtil: - """ - - """ - - @staticmethod - def get_user_group_permissions(user_group): - return user_group.nodepermission_set.all() \ - .filter(is_active=True) \ - .filter(date_expired__gt=timezone.now()) - - @staticmethod - def get_system_user_permissions(system_user): - return system_user.nodepermission_set.all() \ - .filter(is_active=True) \ - .filter(date_expired__gt=timezone.now()) - - @classmethod - def get_user_group_nodes(cls, user_group): - """ - 获取用户组授权的node和系统用户 - :param user_group: - :return: {"node": set(systemuser1, systemuser2), ..} - """ - permissions = cls.get_user_group_permissions(user_group) - nodes_directed = collections.defaultdict(set) - - for perm in permissions: - nodes_directed[perm.node].add(perm.system_user) - - nodes = copy.deepcopy(nodes_directed) - for node, system_users in nodes_directed.items(): - for child in node.get_all_children_with_self(): - nodes[child].update(system_users) - return nodes - - @classmethod - def get_user_group_nodes_with_assets(cls, user_group): - """ - 获取用户组授权的节点和系统用户,节点下带有资产 - :param user_group: - :return: {"node": {"assets": "", "system_user": ""}, {}} - """ - nodes = cls.get_user_group_nodes(user_group) - nodes_with_assets = dict() + def get_assets(self): + assets = self.get_assets_direct() + nodes = self.get_nodes_direct() for node, system_users in nodes.items(): - nodes_with_assets[node] = { - 'assets': node.get_valid_assets(), - 'system_users': system_users - } - return nodes_with_assets - - @classmethod - def get_user_group_assets(cls, user_group): - assets = collections.defaultdict(set) - permissions = cls.get_user_group_permissions(user_group) - - for perm in permissions: - for asset in perm.node.get_all_assets(): - assets[asset].add(perm.system_user) + _assets = node.get_all_assets().valid().prefetch_related('nodes') + for asset in _assets: + if isinstance(asset, Node): + print(_assets) + assets[asset].update(system_users) return assets - @classmethod - def get_user_nodes(cls, user): - nodes = collections.defaultdict(set) - groups = user.groups.all() - for group in groups: - group_nodes = cls.get_user_group_nodes(group) - for node, system_users in group_nodes.items(): - nodes[node].update(system_users) + def get_nodes_with_assets(self): + """ + 返回节点并且包含资产 + {"node": {"assets": set("system_user")}} + :return: + """ + assets = self.get_assets() + nodes = defaultdict(dict) + for asset, system_users in assets.items(): + _nodes = asset.nodes.all() + for node in _nodes: + if asset in nodes[node]: + nodes[node][asset].update(system_users) + else: + nodes[node][asset] = system_users return nodes - @classmethod - def get_user_nodes_with_assets(cls, user): - nodes = cls.get_user_nodes(user) - nodes_with_assets = dict() - for node, system_users in nodes.items(): - nodes_with_assets[node] = { - 'assets': node.get_valid_assets(), - 'system_users': system_users - } - return nodes_with_assets - - @classmethod - def get_user_assets(cls, user): - assets = collections.defaultdict(set) - nodes_with_assets = cls.get_user_nodes_with_assets(user) - - for v in nodes_with_assets.values(): - for asset in v['assets']: - assets[asset].update(v['system_users']) - return assets - - @classmethod - def get_system_user_assets(cls, system_user): - assets = set() - permissions = cls.get_system_user_permissions(system_user) - - for perm in permissions: - assets.update(perm.node.get_all_assets()) - return assets diff --git a/utils/upgrade.sh b/utils/upgrade.sh index e1cdabb76..f700c9f7d 100644 --- a/utils/upgrade.sh +++ b/utils/upgrade.sh @@ -1,6 +1,6 @@ #!/bin/bash -if grep -q 'source ~/.autoenv/activate.sh' ~/.bashrc; then +if grep -q 'source /opt/autoenv/activate.sh' ~/.bashrc; then echo -e "\033[31m 正在自动载入 python 环境 \033[0m" else echo -e "\033[31m 不支持自动升级,请参考 http://docs.jumpserver.org/zh/docs/upgrade.html 手动升级 \033[0m" @@ -40,5 +40,6 @@ git pull && pip install -r requirements/requirements.txt && cd utils && sh make_ cd .. && ./jms start all -d echo -e "\033[31m 请检查jumpserver是否启动成功 \033[0m" echo -e "\033[31m 备份文件存放于$jumpserver_backup目录 \033[0m" +stty erase ^? exit 0