mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			182 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			182 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
from django.core.exceptions import ValidationError
 | 
						|
from django.shortcuts import get_object_or_404
 | 
						|
from django.utils.translation import gettext_lazy as _
 | 
						|
from rest_framework import status
 | 
						|
from rest_framework.decorators import action
 | 
						|
from rest_framework.response import Response
 | 
						|
 | 
						|
from common.api.generic import JMSModelViewSet
 | 
						|
from orgs.mixins.api import OrgBulkModelViewSet
 | 
						|
from orgs.mixins.models import OrgModelMixin
 | 
						|
from orgs.utils import current_org
 | 
						|
from rbac.models import ContentType
 | 
						|
from rbac.serializers import ContentTypeSerializer
 | 
						|
from . import serializers
 | 
						|
from .const import label_resource_types
 | 
						|
from .models import Label, LabeledResource
 | 
						|
 | 
						|
# Public viewsets exported by this module (all viewsets defined below).
__all__ = [
    'LabelViewSet', 'ContentTypeViewSet',
    'LabelContentTypeResourceViewSet', 'LabeledResourceViewSet',
]
 | 
						|
 | 
						|
 | 
						|
class ContentTypeViewSet(JMSModelViewSet):
    """Read-only API over the content types that can carry labels.

    Only GET/HEAD/OPTIONS are allowed. The ``resources`` detail action lists
    the objects of one content type, optionally narrowed by ``?search=``.
    """
    serializer_class = ContentTypeSerializer
    http_method_names = ['get', 'head', 'options']
    rbac_perms = {
        'resources': 'rbac.view_contenttype',
    }
    # NOTE(review): presumably disables the default page limit for the
    # content-type list itself — confirm against JMSModelViewSet.
    page_default_limit = None
    can_labeled_content_type = []
    model = ContentType

    def get_queryset(self):
        """Return the content types that support labels.

        ``label_resource_types`` is a module-level object shared by every
        request. When it is a QuerySet, return a fresh clone via ``.all()``
        so that evaluating it for one request does not populate (and pin) a
        result cache on the shared object — the same reason DRF's
        ``GenericAPIView.get_queryset`` calls ``.all()``. Non-QuerySet values
        are returned unchanged.
        """
        from django.db.models import QuerySet

        queryset = label_resource_types
        if isinstance(queryset, QuerySet):
            queryset = queryset.all()
        return queryset

    @action(methods=['GET'], detail=True, serializer_class=serializers.ContentTypeResourceSerializer)
    def resources(self, request, *args, **kwargs):
        """List the objects of the selected content type.

        Sets ``page_default_limit`` to 100 for this action. The ``search``
        query parameter is delegated to ``content_type.filter_queryset``.
        """
        # Unlike the list view above, this listing uses a page limit of 100.
        self.page_default_limit = 100
        content_type = self.get_object()
        model = content_type.model_class()

        if issubclass(model, OrgModelMixin):
            # Org-bound models: restrict to rows of the current organization.
            queryset = model.objects.filter(org_id=current_org.id)
        elif hasattr(model, 'get_queryset'):
            # The model class provides its own base queryset.
            queryset = model.get_queryset()
        else:
            queryset = model.objects.all()

        keyword = request.query_params.get('search')
        if keyword:
            queryset = content_type.filter_queryset(queryset, keyword)
        return self.get_paginated_response_from_queryset(queryset)
 | 
						|
 | 
						|
 | 
						|
class LabelContentTypeResourceViewSet(JMSModelViewSet):
    """Resources of one content type, seen from the point of view of a label.

    URL kwargs: ``label`` (Label pk) and ``res_type`` (ContentType id).
    GET lists resources bound to the label (``?bound=1``, the default) or
    not bound to it (any other ``bound`` value). PUT replaces the label's
    bindings for that content type with the ids in ``res_ids``.
    """
    serializer_class = serializers.ContentTypeResourceSerializer
    rbac_perms = {
        'default': 'labels.view_labeledresource',
        'update': 'labels.change_labeledresource',
    }
    ordering_fields = ('res_type', 'date_created')

    def get_queryset(self):
        """Return the bound or unbound resources, optionally searched.

        Raises Http404 (via ``get_object_or_404``) when the label or the
        content type does not exist.
        """
        label_pk = self.kwargs.get('label')
        res_type = self.kwargs.get('res_type')
        label = get_object_or_404(Label, pk=label_pk)
        content_type = get_object_or_404(ContentType, id=res_type)
        bound = self.request.query_params.get('bound', '1')

        # Ids of resources of this content type that currently carry the label.
        res_ids = LabeledResource.objects \
            .filter(res_type=content_type, label=label) \
            .values_list('res_id', flat=True)
        res_ids = set(res_ids)

        model = content_type.model_class()
        if hasattr(model, 'get_queryset'):
            queryset = model.get_queryset()
        else:
            queryset = model.objects.all()

        if bound == '1':
            queryset = queryset.filter(id__in=list(res_ids))
        else:
            queryset = queryset.exclude(id__in=list(res_ids))

        keyword = self.request.query_params.get('search')
        if keyword:
            queryset = content_type.filter_queryset(queryset, keyword)
        return queryset

    @staticmethod
    def validate_res_ids(content_type: ContentType, res_ids: list) -> list:
        """Return the ids that cannot be converted by the model's pk field.

        Each id is passed to the pk field's ``to_python``; ids for which it
        raises ``ValidationError`` are collected and returned in input order.
        """
        model_cls = content_type.model_class()
        pk_field = model_cls._meta.pk
        pk_python_type = pk_field.to_python
        invalid_ids = []
        for _id in res_ids:
            try:
                pk_python_type(_id)
            except ValidationError:
                invalid_ids.append(_id)
        return invalid_ids

    @staticmethod
    def _invalid_data_response(detail):
        """Build the 400 response used for malformed ``res_ids`` input."""
        return Response({
            "code": 'invalid_data', "detail": detail,
        }, status=status.HTTP_400_BAD_REQUEST)

    def put(self, request, *args, **kwargs):
        """Replace the label's bindings for this content type with ``res_ids``.

        Bindings whose ``res_id`` is not in ``res_ids`` are deleted and a
        binding is created for every id in ``res_ids`` (``ignore_conflicts``
        skips rows rejected as duplicates — presumably by a unique
        constraint on the binding; confirm on the model). Both steps run in
        one transaction so a failure cannot leave the label half-unbound.

        Returns ``{"total": len(res_ids)}``, or 400 with code
        ``invalid_data`` when ``res_ids`` is not a list or contains ids the
        model's pk field rejects.
        """
        from django.db import transaction

        label_pk = self.kwargs.get('label')
        res_type = self.kwargs.get('res_type')
        content_type = get_object_or_404(ContentType, id=res_type)
        label = get_object_or_404(Label, pk=label_pk)

        data = request.data
        # A non-object body (e.g. a JSON array) has no 'res_ids' key to read.
        res_ids = data.get('res_ids', []) if isinstance(data, dict) else None
        # A bare string would otherwise be iterated character by character.
        if not isinstance(res_ids, (list, tuple)):
            return self._invalid_data_response(f'{_("Invalid data")}: res_ids')

        invalid_ids = self.validate_res_ids(content_type, res_ids)
        if invalid_ids:
            # str() each id: rejected ids are not necessarily strings.
            error = f'{_("Invalid data")}: {", ".join(map(str, invalid_ids))}'
            return self._invalid_data_response(error)

        with transaction.atomic():
            LabeledResource.objects \
                .filter(res_type=content_type, label=label) \
                .exclude(res_id__in=res_ids).delete()
            resources = [
                LabeledResource(res_type=content_type, res_id=res_id, label=label, org_id=current_org.id)
                for res_id in res_ids
            ]
            LabeledResource.objects.bulk_create(resources, ignore_conflicts=True)
        return Response({"total": len(res_ids)})
 | 
						|
 | 
						|
 | 
						|
class LabelViewSet(OrgBulkModelViewSet):
    """CRUD API for labels, plus a ``keys`` action listing distinct names."""
    model = Label
    filterset_fields = ("name", "value")
    search_fields = filterset_fields
    serializer_classes = {
        'default': serializers.LabelSerializer,
        'resource_types': ContentTypeSerializer,
    }
    rbac_perms = {
        'resource_types': 'labels.view_label',
        'keys': 'labels.view_label',
    }

    @action(methods=['GET'], detail=False)
    def keys(self, request, *args, **kwargs):
        """Return the distinct label names, sorted by name.

        ``?search=`` keeps only names containing the keyword
        (case-insensitive). The explicit ``order_by('name')`` replaces any
        default model ordering: Django adds ORDER BY columns to the SELECT of
        a DISTINCT query, so ordering on another column would make the same
        name appear more than once.
        """
        queryset = Label.objects.all()
        keyword = request.query_params.get('search')
        if keyword:
            queryset = queryset.filter(name__icontains=keyword)
        keys = queryset.order_by('name').values_list('name', flat=True).distinct()
        return Response(keys)
 | 
						|
 | 
						|
 | 
						|
class LabeledResourceViewSet(OrgBulkModelViewSet):
    """CRUD API for label-to-resource bindings.

    ``?search=`` is matched case-insensitively against ``str(resource)`` of
    each binding (see ``filter_search``) instead of DRF ``search_fields``.
    """
    model = LabeledResource
    filterset_fields = ("label__name", "label__value", "res_type", "res_id", "label")
    # Keyword search is implemented by filter_search(), not by field search.
    search_fields = []
    serializer_classes = {
        'default': serializers.LabeledResourceSerializer,
    }
    ordering_fields = ('res_type', 'date_created')

    def filter_search(self, queryset, chunk_size=10000):
        """Keep bindings whose resource's string form contains ``search``.

        The match happens in Python, so rows are loaded ``chunk_size`` at a
        time. Returns ``queryset`` unchanged when no (non-blank) keyword is
        given, otherwise ``queryset`` filtered to the matching ids (its
        original ordering is preserved).
        """
        keyword = self.request.query_params.get('search')
        if not keyword:
            return queryset
        keyword = keyword.strip().lower()
        if not keyword:
            # A blank keyword matches every row; skip the full scan.
            return queryset

        matched = []
        # Keyset pagination on the primary key. OFFSET slices of a queryset
        # ordered only by the non-unique ``res_type`` are not stable between
        # queries, so rows could be skipped or visited twice.
        ordered = queryset.order_by('pk')
        last_pk = None
        while True:
            chunk = ordered if last_pk is None else ordered.filter(pk__gt=last_pk)
            page = list(chunk[:chunk_size])
            if not page:
                break
            for instance in page:
                if keyword in str(instance.resource).lower():
                    matched.append(instance.id)
            last_pk = page[-1].pk
        return queryset.filter(id__in=matched)

    def get_queryset(self):
        """Return the base queryset ordered by resource type."""
        queryset = super().get_queryset()
        queryset = queryset.order_by('res_type')
        return queryset

    def filter_queryset(self, queryset):
        """Apply the standard filter backends, then the resource-name search."""
        queryset = super().filter_queryset(queryset)
        queryset = self.filter_search(queryset)
        return queryset
 |