mirror of https://github.com/jumpserver/jumpserver
96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
import coreapi
|
|
from rest_framework import filters
|
|
from rest_framework.fields import DateTimeField
|
|
from rest_framework.serializers import ValidationError
|
|
from django.core.cache import cache
|
|
import logging
|
|
|
|
from . import const
|
|
|
|
__all__ = ["DatetimeRangeFilter", "IDSpmFilter", "CustomFilter"]
|
|
|
|
|
|
class DatetimeRangeFilter(filters.BaseFilterBackend):
|
|
def filter_queryset(self, request, queryset, view):
|
|
if not hasattr(view, 'date_range_filter_fields'):
|
|
return queryset
|
|
try:
|
|
fields = dict(view.date_range_filter_fields)
|
|
except ValueError:
|
|
msg = "View {} datetime_filter_fields set is error".format(view.name)
|
|
logging.error(msg)
|
|
return queryset
|
|
kwargs = {}
|
|
for attr, date_range_keyword in fields.items():
|
|
if len(date_range_keyword) != 2:
|
|
continue
|
|
for i, v in enumerate(date_range_keyword):
|
|
value = request.query_params.get(v)
|
|
if not value:
|
|
continue
|
|
try:
|
|
field = DateTimeField()
|
|
value = field.to_internal_value(value)
|
|
if i == 0:
|
|
lookup = "__gte"
|
|
else:
|
|
lookup = "__lte"
|
|
kwargs[attr+lookup] = value
|
|
except ValidationError as e:
|
|
print(e)
|
|
continue
|
|
if kwargs:
|
|
queryset = queryset.filter(**kwargs)
|
|
return queryset
|
|
|
|
|
|
class IDSpmFilter(filters.BaseFilterBackend):
|
|
def get_schema_fields(self, view):
|
|
return [
|
|
coreapi.Field(
|
|
name='spm', location='query', required=False,
|
|
type='string', example='',
|
|
description='Pre post objects id get spm id, then using filter'
|
|
)
|
|
]
|
|
|
|
def filter_queryset(self, request, queryset, view):
|
|
spm = request.query_params.get('spm')
|
|
if not spm:
|
|
return queryset
|
|
cache_key = const.KEY_CACHE_RESOURCES_ID.format(spm)
|
|
resources_id = cache.get(cache_key)
|
|
if not resources_id or not isinstance(resources_id, list):
|
|
queryset = queryset.none()
|
|
return queryset
|
|
queryset = queryset.filter(id__in=resources_id)
|
|
return queryset
|
|
|
|
|
|
class CustomFilter(filters.BaseFilterBackend):
|
|
|
|
def get_schema_fields(self, view):
|
|
fields = []
|
|
defaults = dict(
|
|
location='query', required=False,
|
|
type='string', example='',
|
|
description=''
|
|
)
|
|
if not hasattr(view, 'custom_filter_fields'):
|
|
return []
|
|
|
|
for field in view.custom_filter_fields:
|
|
if isinstance(field, str):
|
|
defaults['name'] = field
|
|
elif isinstance(field, dict):
|
|
defaults.update(field)
|
|
else:
|
|
continue
|
|
fields.append(coreapi.Field(**defaults))
|
|
return fields
|
|
|
|
def filter_queryset(self, request, queryset, view):
|
|
return queryset
|