perf: 修改 perms

pull/8873/head
ibuler 2022-08-22 18:32:33 +08:00
parent f0c9c2b1ad
commit 09607a1885
27 changed files with 67 additions and 140 deletions

View File

@ -1,4 +1,4 @@
from .common import * from .asset import *
from .host import * from .host import *
from .database import * from .database import *
from .permission import * from .permission import *

View File

@ -2,18 +2,20 @@
# #
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
import django_filters
from common.drf.filters import BaseFilterSet
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from common.mixins.api import SuggestionMixin from common.mixins.api import SuggestionMixin
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from orgs.mixins import generics from orgs.mixins import generics
from assets.api import FilterAssetByNodeMixin
from assets.models import Asset, Node, Gateway from assets.models import Asset, Node, Gateway
from assets import serializers from assets import serializers
from assets.tasks import ( from assets.tasks import (
update_assets_hardware_info_manual, test_assets_connectivity_manual, update_assets_hardware_info_manual, test_assets_connectivity_manual,
) )
from assets.filters import FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend from assets.filters import NodeFilterBackend, LabelFilterBackend, IpInFilterBackend
from ..mixin import NodeFilterMixin
logger = get_logger(__file__) logger = get_logger(__file__)
__all__ = [ __all__ = [
@ -21,17 +23,21 @@ __all__ = [
] ]
class AssetViewSet(SuggestionMixin, FilterAssetByNodeMixin, OrgBulkModelViewSet): class AssetFilterSet(BaseFilterSet):
type = django_filters.CharFilter(field_name='platform__type', lookup_expr='exact')
category = django_filters.CharFilter(field_name='platform__category', lookup_expr='exact')
class Meta:
model = Asset
fields = ['name', 'ip', 'is_active', 'type', 'category']
class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
""" """
API endpoint that allows Asset to be viewed or edited. API endpoint that allows Asset to be viewed or edited.
""" """
model = Asset model = Asset
filterset_fields = { filterset_class = AssetFilterSet
'name': ['exact'],
'ip': ['exact'],
'is_active': ['exact'],
'protocols': ['exact', 'icontains']
}
search_fields = ("name", "ip") search_fields = ("name", "ip")
ordering_fields = ("name", "ip", "port") ordering_fields = ("name", "ip", "port")
ordering = ('name', ) ordering = ('name', )
@ -47,9 +53,9 @@ class AssetViewSet(SuggestionMixin, FilterAssetByNodeMixin, OrgBulkModelViewSet)
('gateways', 'assets.view_gateway') ('gateways', 'assets.view_gateway')
) )
extra_filter_backends = [ extra_filter_backends = [
FilterAssetByNodeFilterBackend,
LabelFilterBackend, LabelFilterBackend,
IpInFilterBackend, IpInFilterBackend,
NodeFilterBackend
] ]
def set_assets_node(self, assets): def set_assets_node(self, assets):

View File

@ -1,7 +1,7 @@
from assets.models import Database from assets.models import Database
from assets.serializers import DatabaseSerializer from assets.serializers import DatabaseSerializer
from .common import AssetViewSet from .asset import AssetViewSet
__all__ = ['DatabaseViewSet'] __all__ = ['DatabaseViewSet']

View File

@ -1,7 +1,7 @@
from assets.models import Host from assets.models import Host
from assets.serializers import HostSerializer from assets.serializers import HostSerializer
from .common import AssetViewSet from .asset import AssetViewSet
__all__ = ['HostViewSet'] __all__ = ['HostViewSet']

View File

@ -1,10 +1,10 @@
from typing import List from typing import List
from rest_framework.request import Request
from common.utils.common import timeit from common.utils import lazyproperty, timeit
from assets.models import Node, Asset from assets.models import Node, Asset
from assets.pagination import NodeAssetTreePagination from assets.pagination import NodeAssetTreePagination
from common.utils import lazyproperty from assets.utils import get_node_from_request, is_query_node_all_assets
from assets.utils import get_node, is_query_node_all_assets
class SerializeToTreeNodeMixin: class SerializeToTreeNodeMixin:
@ -80,8 +80,9 @@ class SerializeToTreeNodeMixin:
return data return data
class FilterAssetByNodeMixin: class NodeFilterMixin:
pagination_class = NodeAssetTreePagination pagination_class = NodeAssetTreePagination
request: Request
@lazyproperty @lazyproperty
def is_query_node_all_assets(self): def is_query_node_all_assets(self):
@ -89,4 +90,4 @@ class FilterAssetByNodeMixin:
@lazyproperty @lazyproperty
def node(self): def node(self):
return get_node(self.request) return get_node_from_request(self.request)

View File

@ -6,7 +6,7 @@ from rest_framework import filters
from django.db.models import Q from django.db.models import Q
from .models import Label from .models import Label
from assets.utils import is_query_node_all_assets, get_node from assets.utils import is_query_node_all_assets, get_node_from_request
class AssetByNodeFilterBackend(filters.BaseFilterBackend): class AssetByNodeFilterBackend(filters.BaseFilterBackend):
@ -31,7 +31,7 @@ class AssetByNodeFilterBackend(filters.BaseFilterBackend):
return queryset.filter(nodes__key=node.key).distinct() return queryset.filter(nodes__key=node.key).distinct()
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
node = get_node(request) node = get_node_from_request(request)
if node is None: if node is None:
return queryset return queryset
@ -42,9 +42,9 @@ class AssetByNodeFilterBackend(filters.BaseFilterBackend):
return self.filter_node_related_direct(queryset, node) return self.filter_node_related_direct(queryset, node)
class FilterAssetByNodeFilterBackend(filters.BaseFilterBackend): class NodeFilterBackend(filters.BaseFilterBackend):
""" """
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用 需要与 `assets.api.mixin.NodeFilterMixin` 配合使用
""" """
fields = ['node', 'all'] fields = ['node', 'all']
@ -58,10 +58,11 @@ class FilterAssetByNodeFilterBackend(filters.BaseFilterBackend):
] ]
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
node = view.node node = get_node_from_request(request)
if node is None: if node is None:
return queryset return queryset
query_all = view.is_query_node_all_assets
query_all = is_query_node_all_assets(request)
if query_all: if query_all:
return queryset.filter( return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes__key__istartswith=f'{node.key}:') |
@ -94,6 +95,9 @@ class LabelFilterBackend(filters.BaseFilterBackend):
for kv in labels_query: for kv in labels_query:
if '#' in kv: if '#' in kv:
self.sep = '#' self.sep = '#'
break
for kv in labels_query:
if self.sep not in kv: if self.sep not in kv:
continue continue
key, value = kv.strip().split(self.sep)[:2] key, value = kv.strip().split(self.sep)[:2]

View File

@ -9,6 +9,7 @@ from functools import reduce
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from common.utils import lazyproperty
from orgs.mixins.models import OrgManager, JMSOrgBaseModel from orgs.mixins.models import OrgManager, JMSOrgBaseModel
from ..platform import Platform from ..platform import Platform
from ..base import AbsConnectivity from ..base import AbsConnectivity
@ -110,11 +111,11 @@ class Asset(AbsConnectivity, NodesRelationMixin, JMSOrgBaseModel):
names.append(n.name + ':' + n.value) names.append(n.name + ':' + n.value)
return names return names
@property @lazyproperty
def type(self): def type(self):
return self.platform.type return self.platform.type
@property @lazyproperty
def category(self): def category(self):
return self.platform.category return self.platform.category

View File

@ -8,6 +8,9 @@ logger = get_logger(__name__)
class AssetPaginationBase(LimitOffsetPagination): class AssetPaginationBase(LimitOffsetPagination):
_request = None
_view = None
_user = None
def init_attrs(self, queryset, request: Request, view=None): def init_attrs(self, queryset, request: Request, view=None):
self._request = request self._request = request
@ -28,7 +31,8 @@ class AssetPaginationBase(LimitOffsetPagination):
} }
for k, v in self._request.query_params.items(): for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None: if k not in exclude_query_params and v is not None:
logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}') logger.warn(f'Not hit node.assets_amount because find a unknown query_param '
f'`{k}` -> {self._request.get_full_path()}')
return super().get_count(queryset) return super().get_count(queryset)
node_assets_count = self.get_count_from_nodes(queryset) node_assets_count = self.get_count_from_nodes(queryset)
if node_assets_count is None: if node_assets_count is None:
@ -49,4 +53,4 @@ class NodeAssetTreePagination(AssetPaginationBase):
node = Node.org_root() node = Node.org_root()
if node: if node:
logger.debug(f'Hit node assets_amount cache: [{node.assets_amount}]') logger.debug(f'Hit node assets_amount cache: [{node.assets_amount}]')
return node.assets_amount return node.assets_amount

View File

@ -2,6 +2,7 @@
# #
from rest_framework import serializers from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.db.models import F
from common.drf.serializers import JMSWritableNestedModelSerializer from common.drf.serializers import JMSWritableNestedModelSerializer
from common.drf.fields import ChoiceDisplayField from common.drf.fields import ChoiceDisplayField
@ -107,7 +108,9 @@ class AssetSerializer(JMSWritableNestedModelSerializer):
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('domain', 'platform', 'protocols') queryset = queryset.prefetch_related('domain', 'platform', 'protocols')\
.annotate(category=F("platform__category"))\
.annotate(type=F("platform__type"))
queryset = queryset.prefetch_related('nodes', 'labels') queryset = queryset.prefetch_related('nodes', 'labels')
return queryset return queryset

View File

@ -53,7 +53,7 @@ def is_query_node_all_assets(request):
return is_true(query_all_arg) return is_true(query_all_arg)
def get_node(request): def get_node_from_request(request):
node_id = dict_get_any(request.query_params, ['node', 'node_id']) node_id = dict_get_any(request.query_params, ['node', 'node_id'])
if not node_id: if not node_id:
return None return None

View File

@ -1,4 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from .asset import * from .user_permission import *
from .asset_permission import *
from .asset_permission_relation import *
from .user_group_permission import *

View File

@ -1,4 +0,0 @@
from .user_permission import *
from .asset_permission import *
from .asset_permission_relation import *
from .user_group_permission import *

View File

@ -1,6 +0,0 @@
# -*- coding: utf-8 -*-
#
from .common import *
from .user_permission_nodes import *
from .user_permission_assets import *
from .user_permission_nodes_with_assets import *

View File

@ -1 +0,0 @@
from .views import *

View File

@ -1,95 +0,0 @@
from django.db.models import Q
from common.utils import get_object_or_none
from orgs.mixins.api import OrgBulkModelViewSet
from assets.models import SystemUser
from users.models import User, UserGroup
__all__ = ['BasePermissionViewSet']
class BasePermissionViewSet(OrgBulkModelViewSet):
custom_filter_fields = [
'user_id', 'username', 'system_user_id', 'system_user',
'user_group_id', 'user_group'
]
def filter_valid(self, queryset):
valid_query = self.request.query_params.get('is_valid', None)
if valid_query is None:
return queryset
invalid = valid_query in ['0', 'N', 'false', 'False']
if invalid:
queryset = queryset.invalid()
else:
queryset = queryset.valid()
return queryset
def is_query_all(self):
query_all = self.request.query_params.get('all', '1') == '1'
return query_all
def filter_user(self, queryset):
user_id = self.request.query_params.get('user_id')
username = self.request.query_params.get('username')
if user_id:
user = get_object_or_none(User, pk=user_id)
elif username:
user = get_object_or_none(User, username=username)
else:
return queryset
if not user:
return queryset.none()
if not self.is_query_all():
queryset = queryset.filter(users=user)
return queryset
groups = list(user.groups.all().values_list('id', flat=True))
queryset = queryset.filter(
Q(users=user) | Q(user_groups__in=groups)
).distinct()
return queryset
def filter_keyword(self, queryset):
keyword = self.request.query_params.get('search')
if not keyword:
return queryset
queryset = queryset.filter(name__icontains=keyword)
return queryset
def filter_system_user(self, queryset):
system_user_id = self.request.query_params.get('system_user_id')
system_user_name = self.request.query_params.get('system_user')
if system_user_id:
system_user = get_object_or_none(SystemUser, pk=system_user_id)
elif system_user_name:
system_user = get_object_or_none(SystemUser, name=system_user_name)
else:
return queryset
if not system_user:
return queryset.none()
queryset = queryset.filter(system_users=system_user)
return queryset
def filter_user_group(self, queryset):
user_group_id = self.request.query_params.get('user_group_id')
user_group_name = self.request.query_params.get('user_group')
if user_group_id:
group = get_object_or_none(UserGroup, pk=user_group_id)
elif user_group_name:
group = get_object_or_none(UserGroup, name=user_group_name)
else:
return queryset
if not group:
return queryset.none()
queryset = queryset.filter(user_groups=group)
return queryset
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_valid(queryset)
queryset = self.filter_user(queryset)
queryset = self.filter_system_user(queryset)
queryset = self.filter_user_group(queryset)
queryset = self.filter_keyword(queryset)
queryset = queryset.distinct()
return queryset

View File

@ -9,7 +9,7 @@ from rest_framework.response import Response
from common.utils import lazyproperty from common.utils import lazyproperty
from perms.models import AssetPermission from perms.models import AssetPermission
from assets.models import Asset, Node from assets.models import Asset, Node
from perms.api.asset import user_permission as uapi from . import user_permission as uapi
from perms import serializers from perms import serializers
from perms.utils.asset.permission import get_asset_system_user_ids_with_actions_by_group from perms.utils.asset.permission import get_asset_system_user_ids_with_actions_by_group
from assets.api.mixin import SerializeToTreeNodeMixin from assets.api.mixin import SerializeToTreeNodeMixin

View File

@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
#
from .common import *
from .nodes import *
from .assets import *
from .nodes_with_assets import *

View File

@ -0,0 +1 @@
from .api import *

View File

@ -9,6 +9,8 @@ logger = get_logger(__name__)
class GrantedAssetPaginationBase(AssetPaginationBase): class GrantedAssetPaginationBase(AssetPaginationBase):
_user: object
def init_attrs(self, queryset, request: Request, view=None): def init_attrs(self, queryset, request: Request, view=None):
super().init_attrs(queryset, request, view) super().init_attrs(queryset, request, view)
self._user = view.user self._user = view.user
@ -18,10 +20,12 @@ class NodeGrantedAssetPagination(GrantedAssetPaginationBase):
def get_count_from_nodes(self, queryset): def get_count_from_nodes(self, queryset):
node = getattr(self._view, 'pagination_node', None) node = getattr(self._view, 'pagination_node', None)
if node: if node:
logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}') logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> '
f'{self._request.get_full_path()}')
return node.assets_amount return node.assets_amount
else: else:
logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} not has `pagination_node` -> {self._request.get_full_path()}') logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} '
f'not has `pagination_node` -> {self._request.get_full_path()}')
return None return None