You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
jumpserver/apps/terminal/api/component/endpoint.py

98 lines
3.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from django.shortcuts import get_object_or_404
from django.utils.translation import gettext_lazy as _
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from assets.models import Asset
from authentication.permissions import IsValidUserOrConnectionToken
from common.api import JMSBulkModelViewSet
from orgs.utils import tmp_to_root_org
from terminal import serializers
from terminal.models import Session, Endpoint, EndpointRule
__all__ = ['EndpointViewSet', 'EndpointRuleViewSet']
class SmartEndpointViewMixin:
get_serializer: callable
request: Request
# View 处理过程中用的属性
target_instance: None
target_protocol: None
@action(methods=['get'], detail=False, permission_classes=[IsValidUserOrConnectionToken])
@tmp_to_root_org()
def smart(self, request, *args, **kwargs):
self.target_instance = self.get_target_instance()
self.target_protocol = self.get_target_protocol()
if not self.target_protocol:
error = _('Not found protocol query params')
return Response(data={'error': error}, status=status.HTTP_404_NOT_FOUND)
endpoint = self.match_endpoint()
serializer = self.get_serializer(endpoint)
return Response(serializer.data)
def match_endpoint(self):
endpoint = self.match_endpoint_by_label()
if not endpoint:
endpoint = self.match_endpoint_by_target_ip()
return endpoint
def match_endpoint_by_label(self):
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request)
def match_endpoint_by_target_ip(self):
target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数用来方便测试
if not target_ip and callable(getattr(self.target_instance, 'get_target_ip', None)):
target_ip = self.target_instance.get_target_ip()
endpoint = EndpointRule.match_endpoint(
self.target_instance, target_ip, self.target_protocol, self.request
)
return endpoint
def get_target_instance(self):
request = self.request
asset_id = request.GET.get('asset_id')
session_id = request.GET.get('session_id')
token_id = request.GET.get('token')
if token_id:
from authentication.models import ConnectionToken
token = ConnectionToken.objects.filter(id=token_id).first()
if token and token.asset:
asset_id = token.asset.id
if asset_id:
pk, model = asset_id, Asset
elif session_id:
pk, model = session_id, Session
else:
pk, model = None, None
if not pk or not model:
return None
with tmp_to_root_org():
instance = get_object_or_404(model, pk=pk)
return instance
def get_target_protocol(self):
protocol = None
if not protocol:
protocol = self.request.GET.get('protocol')
return protocol
class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet):
filterset_fields = ('name', 'host')
search_fields = filterset_fields
serializer_class = serializers.EndpointSerializer
queryset = Endpoint.objects.all()
class EndpointRuleViewSet(JMSBulkModelViewSet):
filterset_fields = ('name',)
search_fields = filterset_fields
serializer_class = serializers.EndpointRuleSerializer
queryset = EndpointRule.objects.all()