fix:绑定的端点Default下载RDP文件中的地址是空的

pull/12577/head^2
wangruidong 2024-01-18 18:35:32 +08:00 committed by Bryan
parent cd0348cca1
commit 3853d0bcc6
3 changed files with 18 additions and 12 deletions

View File

@ -205,7 +205,7 @@ class RDPFileClientProtocolURLMixin:
return data
def get_smart_endpoint(self, protocol, asset=None):
endpoint = Endpoint.match_by_instance_label(asset, protocol)
endpoint = Endpoint.match_by_instance_label(asset, protocol, self.request)
if not endpoint:
target_ip = asset.get_target_ip() if asset else ''
endpoint = EndpointRule.match_endpoint(

View File

@ -42,7 +42,7 @@ class SmartEndpointViewMixin:
return endpoint
def match_endpoint_by_label(self):
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request)
def match_endpoint_by_target_ip(self):
target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数用来方便测试

View File

@ -75,7 +75,20 @@ class Endpoint(JMSBaseModel):
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol):
def handle_endpoint_host(cls, endpoint, request=None):
if not endpoint.host and request:
# 动态添加 current request host
host_port = request.get_host()
# IPv6
if host_port.startswith('['):
host = host_port.split(']:')[0].rstrip(']') + ']'
else:
host = host_port.split(':')[0]
endpoint.host = host
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol, request=None):
from assets.models import Asset
from terminal.models import Session
if isinstance(instance, Session):
@ -88,6 +101,7 @@ class Endpoint(JMSBaseModel):
endpoints = cls.objects.filter(name__in=list(values)).order_by('-date_updated')
for endpoint in endpoints:
if endpoint.is_valid_for(instance, protocol):
endpoint = cls.handle_endpoint_host(endpoint, request)
return endpoint
@ -130,13 +144,5 @@ class EndpointRule(JMSBaseModel):
endpoint = endpoint_rule.endpoint
else:
endpoint = Endpoint.get_or_create_default(request)
if not endpoint.host and request:
# 动态添加 current request host
host_port = request.get_host()
# IPv6
if host_port.startswith('['):
host = host_port.split(']:')[0].rstrip(']') + ']'
else:
host = host_port.split(':')[0]
endpoint.host = host
endpoint = Endpoint.handle_endpoint_host(endpoint, request)
return endpoint