mirror of https://github.com/jumpserver/jumpserver
perf: 去掉 connect token endpoint protocol
parent
dd207016b2
commit
44ee80b05a
|
@ -19,12 +19,13 @@ from common.utils import random_string
|
|||
from common.utils.django import get_request_os
|
||||
from orgs.mixins.api import RootOrgViewMixin
|
||||
from perms.models import ActionChoices
|
||||
from terminal.const import NativeClient
|
||||
from terminal.const import NativeClient, TerminalType
|
||||
from terminal.models import EndpointRule, Applet
|
||||
from ..models import ConnectionToken
|
||||
from ..serializers import (
|
||||
ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
|
||||
SuperConnectionTokenSerializer, )
|
||||
SuperConnectionTokenSerializer,
|
||||
)
|
||||
|
||||
__all__ = ['ConnectionTokenViewSet', 'SuperConnectionTokenViewSet']
|
||||
|
||||
|
@ -143,9 +144,12 @@ class RDPFileClientProtocolURLMixin:
|
|||
def get_client_protocol_data(self, token: ConnectionToken):
|
||||
_os = get_request_os(self.request)
|
||||
|
||||
connect_method = getattr(NativeClient, token.connect_method, None)
|
||||
if connect_method is None:
|
||||
raise ValueError('Connect method not support: {}'.format(token.connect_method))
|
||||
connect_method_name = token.connect_method
|
||||
connect_method_dict = TerminalType.get_connect_method(
|
||||
token.connect_method, token.protocol, _os
|
||||
)
|
||||
if connect_method_dict is None:
|
||||
raise ValueError('Connect method not support: {}'.format(connect_method_name))
|
||||
|
||||
data = {
|
||||
'id': str(token.id),
|
||||
|
@ -154,7 +158,7 @@ class RDPFileClientProtocolURLMixin:
|
|||
'file': {}
|
||||
}
|
||||
|
||||
if connect_method == NativeClient.mstsc:
|
||||
if connect_method_name == NativeClient.mstsc:
|
||||
filename, content = self.get_rdp_file_info(token)
|
||||
data.update({
|
||||
'file': {
|
||||
|
@ -163,8 +167,11 @@ class RDPFileClientProtocolURLMixin:
|
|||
}
|
||||
})
|
||||
else:
|
||||
endpoint = self.get_smart_endpoint(protocol=token.endpoint_protocol, asset=token.asset)
|
||||
cmd = NativeClient.get_launch_command(connect_method, token, endpoint)
|
||||
endpoint = self.get_smart_endpoint(
|
||||
protocol=connect_method_dict['endpoint_protocol'],
|
||||
asset=token.asset
|
||||
)
|
||||
cmd = NativeClient.get_launch_command(connect_method_name, token, endpoint)
|
||||
data.update({'command': cmd})
|
||||
return data
|
||||
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 3.2.14 on 2022-11-29 13:27
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('authentication', '0018_connectiontoken_endpoint_protocol'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name='connectiontoken',
|
||||
name='endpoint_protocol',
|
||||
),
|
||||
]
|
|
@ -35,9 +35,6 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel):
|
|||
choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol")
|
||||
)
|
||||
connect_method = models.CharField(max_length=32, verbose_name=_("Connect method"))
|
||||
endpoint_protocol = models.CharField(
|
||||
choices=Protocol.choices, max_length=16, verbose_name=_("Endpoint protocol")
|
||||
)
|
||||
user_display = models.CharField(max_length=128, default='', verbose_name=_("User display"))
|
||||
asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display"))
|
||||
date_expired = models.DateTimeField(
|
||||
|
|
|
@ -23,7 +23,7 @@ class ConnectionTokenSerializer(OrgResourceModelSerializerMixin):
|
|||
fields_small = fields_mini + [
|
||||
'user', 'asset', 'account_name',
|
||||
'input_username', 'input_secret',
|
||||
'connect_method', 'endpoint_protocol', 'protocol',
|
||||
'connect_method', 'protocol',
|
||||
'actions', 'date_expired', 'date_created',
|
||||
'date_updated', 'created_by',
|
||||
'updated_by', 'org_id', 'org_name',
|
||||
|
|
|
@ -204,6 +204,15 @@ class TerminalType(TextChoices):
|
|||
}
|
||||
return protocols
|
||||
|
||||
@classmethod
|
||||
def get_connect_method(cls, name, protocol, os):
|
||||
methods = cls.get_protocols_connect_methods(os)
|
||||
protocol_methods = methods.get(protocol, [])
|
||||
for method in protocol_methods:
|
||||
if method['value'] == name:
|
||||
return method
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_protocols_connect_methods(cls, os):
|
||||
methods = defaultdict(list)
|
||||
|
|
Loading…
Reference in New Issue