jumpserver/apps/assets/filters.py

140 lines
4.3 KiB
Python

# -*- coding: utf-8 -*-
#
from rest_framework.compat import coreapi, coreschema
from rest_framework import filters
from django.db.models import Q
from common.utils import dict_get_any, is_uuid, get_object_or_none
from .models import Node, Label
class AssetByNodeFilterBackend(filters.BaseFilterBackend):
fields = ['node', 'all']
def get_schema_fields(self, view):
return [
coreapi.Field(
name=field, location='query', required=False,
type='string', example='', description='', schema=None,
)
for field in self.fields
]
@staticmethod
def is_query_all(request):
query_all_arg = request.query_params.get('all')
show_current_asset_arg = request.query_params.get('show_current_asset')
query_all = query_all_arg == '1'
if show_current_asset_arg is not None:
query_all = show_current_asset_arg != '1'
return query_all
@staticmethod
def get_query_node(request):
node_id = dict_get_any(request.query_params, ['node', 'node_id'])
if not node_id:
return None, False
if is_uuid(node_id):
node = get_object_or_none(Node, id=node_id)
else:
node = get_object_or_none(Node, key=node_id)
return node, True
@staticmethod
def perform_query(pattern, queryset):
return queryset.filter(nodes__key__regex=pattern).distinct()
def filter_queryset(self, request, queryset, view):
node, has_query_arg = self.get_query_node(request)
if not has_query_arg:
return queryset
if node is None:
return queryset
query_all = self.is_query_all(request)
if query_all:
pattern = node.get_all_children_pattern(with_self=True)
else:
# pattern = node.get_children_key_pattern(with_self=True)
# 只显示当前节点下资产
pattern = r"^{}$".format(node.key)
return self.perform_query(pattern, queryset)
class LabelFilterBackend(filters.BaseFilterBackend):
sep = ':'
query_arg = 'label'
def get_schema_fields(self, view):
example = self.sep.join(['os', 'linux'])
return [
coreapi.Field(
name=self.query_arg, location='query', required=False,
type='string', example=example, description=''
)
]
def get_query_labels(self, request):
labels_query = request.query_params.getlist(self.query_arg)
if not labels_query:
return None
q = None
for kv in labels_query:
if '#' in kv:
self.sep = '#'
if self.sep not in kv:
continue
key, value = kv.strip().split(self.sep)[:2]
if not all([key, value]):
continue
if q:
q |= Q(name=key, value=value)
else:
q = Q(name=key, value=value)
if not q:
return []
labels = Label.objects.filter(q, is_active=True)\
.values_list('id', flat=True)
return labels
def filter_queryset(self, request, queryset, view):
labels = self.get_query_labels(request)
if labels is None:
return queryset
if len(labels) == 0:
return queryset.none()
for label in labels:
queryset = queryset.filter(labels=label)
return queryset
class AssetRelatedByNodeFilterBackend(AssetByNodeFilterBackend):
@staticmethod
def perform_query(pattern, queryset):
return queryset.filter(asset__nodes__key__regex=pattern).distinct()
class IpInFilterBackend(filters.BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
ips = request.query_params.get('ips')
if not ips:
return queryset
ip_list = [i.strip() for i in ips.split(',')]
queryset = queryset.filter(ip__in=ip_list)
return queryset
def get_schema_fields(self, view):
return [
coreapi.Field(
name='ips', location='query', required=False, type='string',
schema=coreschema.String(
title='ips',
description='ip in filter'
)
)
]