mirror of https://github.com/jumpserver/jumpserver
154 lines
4.6 KiB
Python
154 lines
4.6 KiB
Python
# ~*~ coding: utf-8 ~*~
|
|
|
|
from rest_framework.generics import get_object_or_404
|
|
from rest_framework.response import Response
|
|
|
|
from assets.locks import NodeAddChildrenLock
|
|
from common.tree import TreeNodeSerializer
|
|
from common.utils import get_logger
|
|
from orgs.mixins import generics
|
|
from orgs.utils import current_org
|
|
from .mixin import SerializeToTreeNodeMixin
|
|
from .. import serializers
|
|
from ..const import AllTypes
|
|
from ..models import Node, Platform, Asset
|
|
|
|
# Module-level logger, created from this file's path.
logger = get_logger(__file__)

# Names exported by ``from <module> import *``: the API views defined below.
__all__ = ['NodeChildrenApi', 'NodeChildrenAsTreeApi', 'CategoryTreeApi']
|
|
|
|
|
|
class NodeChildrenApi(generics.ListCreateAPIView):
    """
    List the children of a node and create a new child under it.

    The parent node is resolved once per request in ``initial()`` from the
    URL kwarg ``pk``, or the query parameters ``id`` / ``key``, and stored on
    ``self.instance``. When none of them is supplied the request is flagged
    as "initial" (``self.is_initial``) and the parent is ``Node.org_root()``,
    or ``None`` when the current org is the root org.
    """
    serializer_class = serializers.NodeSerializer
    search_fields = ('value',)

    # Parent node for the current request; None when no parent could be
    # resolved (initial request in the root org).
    instance = None
    # True when the request named no parent node (no pk / id / key).
    is_initial = False

    def initial(self, request, *args, **kwargs):
        """Resolve the parent node before the standard DRF initialisation."""
        self.instance = self.get_object()
        return super().initial(request, *args, **kwargs)

    def perform_create(self, serializer):
        """
        Create a child of ``self.instance`` from the validated data.

        Uses the submitted ``value`` as the child's name, falling back to the
        parent's next preset child name when it is empty.

        Raises:
            ValidationError: when no parent node is available, which
                ``get_object()`` allows for a request without pk/id/key in
                the root org. Without this guard the code below would fail
                with an AttributeError on ``None``.
        """
        if self.instance is None:
            # Imported locally so the fix does not depend on the module's
            # import block being changed.
            from rest_framework.exceptions import ValidationError
            raise ValidationError(
                'A parent node (id or key) is required to create a child node.'
            )

        with NodeAddChildrenLock(self.instance):
            data = serializer.validated_data
            _id = data.get("id")
            value = data.get("value")
            if not value:
                value = self.instance.get_next_child_preset_name()
            node = self.instance.create_child(value=value, _id=_id)
            # Avoid querying the full value
            node._full_value = node.value
            serializer.instance = node

    def get_object(self):
        """
        Return the parent node for this request.

        Lookup order: URL kwarg ``pk`` or query parameter ``id`` (by primary
        key), then query parameter ``key``. Unknown pk/key results in a 404
        via ``get_object_or_404``. With no identifier at all the request is
        marked as initial and the org root node is returned, or ``None`` in
        the root org.
        """
        pk = self.kwargs.get('pk') or self.request.query_params.get('id')
        key = self.request.query_params.get("key")

        if not pk and not key:
            self.is_initial = True
            if current_org.is_root():
                return None
            return Node.org_root()

        if pk:
            return get_object_or_404(Node, pk=pk)
        return get_object_or_404(Node, key=key)

    def get_org_root_queryset(self, query_all):
        """Return every node when ``query_all`` is truthy, else the org root nodes."""
        if query_all:
            return Node.objects.all()
        return Node.org_root_nodes()

    def get_queryset(self):
        """
        Return the nodes to list for this request.

        The query parameter ``all=all`` switches from ``get_children`` to
        ``get_all_children``. On an initial request the parent itself is
        included (``with_self=True``); in the root org an initial request is
        answered from ``get_org_root_queryset`` instead. With no parent node
        an empty queryset is returned.
        """
        query_all = self.request.query_params.get("all", "0") == "all"

        if self.is_initial and current_org.is_root():
            return self.get_org_root_queryset(query_all)

        if not self.instance:
            return Node.objects.none()

        # ``is_initial`` is only ever True or False, so it can be used
        # directly as the ``with_self`` flag.
        with_self = self.is_initial
        if query_all:
            return self.instance.get_all_children(with_self=with_self)
        return self.instance.get_children(with_self=with_self)
|
|
|
|
|
|
class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
    """
    Return a node's children serialized as tree nodes:
    [
      {
        "id": "",
        "name": "",
        "pId": "",
        "meta": ""
      }
    ]

    """
    model = Node

    def filter_queryset(self, queryset):
        """
        Apply the inherited filters only when a ``search`` term is present.

        The filtered result is then passed through
        ``Node.get_ancestor_queryset`` (presumably so matched nodes keep
        their ancestors in the tree; confirm against the model).
        """
        if self.request.GET.get('search'):
            matched = super().filter_queryset(queryset)
            return self.model.get_ancestor_queryset(matched)
        return queryset

    def list(self, request, *args, **kwargs):
        """Respond with the child nodes (ordered by ``value``) followed by asset nodes."""
        base_queryset = self.get_queryset()
        ordered_nodes = self.filter_queryset(base_queryset).order_by('value')
        tree_nodes = self.serialize_nodes(ordered_nodes, with_asset_amount=True)
        asset_nodes = self.get_assets_as_node()
        return Response(data=[*tree_nodes, *asset_nodes])

    def get_assets_as_node(self):
        """
        Return the parent node's assets as tree nodes.

        Yields an empty list unless a parent node is set and the request
        carries the query parameter ``assets=1``.
        """
        wants_assets = self.request.query_params.get('assets', '0') == '1'
        if self.instance and wants_assets:
            node_assets = self.instance.get_assets_for_tree()
            return self.serialize_assets(node_assets, self.instance.key)
        return []
|
|
|
|
|
|
class CategoryTreeApi(SerializeToTreeNodeMixin, generics.ListAPIView):
    """
    Serve the asset type tree, or the assets of a single platform.

    With ``assets=1`` and a ``key`` query parameter, the response is the
    list of assets whose platform id equals ``key``; otherwise it is the
    result of ``AllTypes.to_tree_nodes``.
    """
    serializer_class = TreeNodeSerializer
    rbac_perms = {
        'GET': 'assets.view_asset',
        'list': 'assets.view_asset',
    }

    def get_assets(self):
        """Return the serialized assets of the platform whose id is ``key``, or ``[]`` if none."""
        platform_key = self.request.query_params.get('key')
        platform = Platform.objects.filter(id=platform_key).first()
        if platform:
            platform_assets = Asset.objects.filter(platform=platform)
            platform_assets = platform_assets.prefetch_related('platform')
            return self.serialize_assets(platform_assets, platform_key)
        return []

    def list(self, request, *args, **kwargs):
        """Respond with platform assets or the type tree, depending on the query parameters."""
        params = self.request.query_params
        include_asset = params.get('assets', '0') == '1'

        if include_asset and params.get('key'):
            tree_data = self.get_assets()
        else:
            tree_data = AllTypes.to_tree_nodes(include_asset)
        return Response(data=tree_data)
|