mirror of https://github.com/jumpserver/jumpserver
145 lines
4.5 KiB
Python
145 lines
4.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
from django.conf import settings
|
|
from rest_framework.response import Response
|
|
from rest_framework import generics, filters
|
|
from rest_framework_bulk import BulkModelViewSet
|
|
|
|
from common.permissions import IsOrgAdminOrAppUser, NeedMFAVerify
|
|
from common.utils import get_object_or_none, get_logger
|
|
from common.mixins import CommonApiMixin
|
|
from ..backends import AssetUserManager
|
|
from ..models import Asset, Node, SystemUser
|
|
from .. import serializers
|
|
from ..tasks import (
|
|
test_asset_users_connectivity_manual, push_system_user_a_asset_manual
|
|
)
|
|
|
|
|
|
__all__ = [
|
|
'AssetUserViewSet', 'AssetUserAuthInfoViewSet', 'AssetUserTaskCreateAPI',
|
|
]
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class AssetUserFilterBackend(filters.BaseFilterBackend):
|
|
def filter_queryset(self, request, queryset, view):
|
|
kwargs = {}
|
|
for field in view.filter_fields:
|
|
value = request.GET.get(field)
|
|
if not value:
|
|
continue
|
|
if field == "node_id":
|
|
value = get_object_or_none(Node, pk=value)
|
|
kwargs["node"] = value
|
|
continue
|
|
elif field == "asset_id":
|
|
field = "asset"
|
|
kwargs[field] = value
|
|
if kwargs:
|
|
queryset = queryset.filter(**kwargs)
|
|
logger.debug("Filter {}".format(kwargs))
|
|
return queryset
|
|
|
|
|
|
class AssetUserSearchBackend(filters.BaseFilterBackend):
|
|
def filter_queryset(self, request, queryset, view):
|
|
value = request.GET.get('search')
|
|
if not value:
|
|
return queryset
|
|
queryset = queryset.search(value)
|
|
return queryset
|
|
|
|
|
|
class AssetUserLatestFilterBackend(filters.BaseFilterBackend):
|
|
def filter_queryset(self, request, queryset, view):
|
|
latest = request.GET.get('latest') == '1'
|
|
if latest:
|
|
queryset = queryset.distinct()
|
|
return queryset
|
|
|
|
|
|
class AssetUserViewSet(CommonApiMixin, BulkModelViewSet):
|
|
serializer_classes = {
|
|
'default': serializers.AssetUserWriteSerializer,
|
|
'list': serializers.AssetUserReadSerializer,
|
|
'retrieve': serializers.AssetUserReadSerializer,
|
|
}
|
|
permission_classes = [IsOrgAdminOrAppUser]
|
|
filter_fields = [
|
|
"id", "ip", "hostname", "username",
|
|
"asset_id", "node_id",
|
|
"prefer", "prefer_id",
|
|
]
|
|
search_fields = ["ip", "hostname", "username"]
|
|
filter_backends = [
|
|
AssetUserFilterBackend, AssetUserSearchBackend,
|
|
AssetUserLatestFilterBackend,
|
|
]
|
|
|
|
def allow_bulk_destroy(self, qs, filtered):
|
|
return False
|
|
|
|
def get_object(self):
|
|
pk = self.kwargs.get("pk")
|
|
queryset = self.get_queryset()
|
|
obj = queryset.get(id=pk)
|
|
return obj
|
|
|
|
def get_exception_handler(self):
|
|
def handler(e, context):
|
|
return Response({"error": str(e)}, status=400)
|
|
return handler
|
|
|
|
def perform_destroy(self, instance):
|
|
manager = AssetUserManager()
|
|
manager.delete(instance)
|
|
|
|
def get_queryset(self):
|
|
manager = AssetUserManager()
|
|
queryset = manager.all()
|
|
return queryset
|
|
|
|
|
|
class AssetUserAuthInfoViewSet(AssetUserViewSet):
|
|
serializer_classes = {"default": serializers.AssetUserAuthInfoSerializer}
|
|
http_method_names = ['get']
|
|
permission_classes = [IsOrgAdminOrAppUser]
|
|
|
|
def get_permissions(self):
|
|
if settings.SECURITY_VIEW_AUTH_NEED_MFA:
|
|
self.permission_classes = [IsOrgAdminOrAppUser, NeedMFAVerify]
|
|
return super().get_permissions()
|
|
|
|
|
|
class AssetUserTaskCreateAPI(generics.CreateAPIView):
|
|
permission_classes = (IsOrgAdminOrAppUser,)
|
|
serializer_class = serializers.AssetUserTaskSerializer
|
|
filter_backends = AssetUserViewSet.filter_backends
|
|
filter_fields = AssetUserViewSet.filter_fields
|
|
|
|
def get_asset_users(self):
|
|
manager = AssetUserManager()
|
|
queryset = manager.all()
|
|
for cls in self.filter_backends:
|
|
queryset = cls().filter_queryset(self.request, queryset, self)
|
|
return list(queryset)
|
|
|
|
def perform_create(self, serializer):
|
|
asset_users = self.get_asset_users()
|
|
# action = serializer.validated_data["action"]
|
|
# only this
|
|
# if action == "test":
|
|
task = test_asset_users_connectivity_manual.delay(asset_users)
|
|
data = getattr(serializer, '_data', {})
|
|
data["task"] = task.id
|
|
setattr(serializer, '_data', data)
|
|
return task
|
|
|
|
def get_exception_handler(self):
|
|
def handler(e, context):
|
|
return Response({"error": str(e)}, status=400)
|
|
return handler
|