diff --git a/apps/authentication/api/connection_token.py b/apps/authentication/api/connection_token.py index 3e1a61a86..67c732519 100644 --- a/apps/authentication/api/connection_token.py +++ b/apps/authentication/api/connection_token.py @@ -19,12 +19,13 @@ from common.utils import random_string from common.utils.django import get_request_os from orgs.mixins.api import RootOrgViewMixin from perms.models import ActionChoices -from terminal.const import NativeClient +from terminal.const import NativeClient, TerminalType from terminal.models import EndpointRule, Applet from ..models import ConnectionToken from ..serializers import ( ConnectionTokenSerializer, ConnectionTokenSecretSerializer, - SuperConnectionTokenSerializer, ) + SuperConnectionTokenSerializer, +) __all__ = ['ConnectionTokenViewSet', 'SuperConnectionTokenViewSet'] @@ -143,9 +144,12 @@ class RDPFileClientProtocolURLMixin: def get_client_protocol_data(self, token: ConnectionToken): _os = get_request_os(self.request) - connect_method = getattr(NativeClient, token.connect_method, None) - if connect_method is None: - raise ValueError('Connect method not support: {}'.format(token.connect_method)) + connect_method_name = token.connect_method + connect_method_dict = TerminalType.get_connect_method( + token.connect_method, token.protocol, _os + ) + if connect_method_dict is None: + raise ValueError('Connect method not support: {}'.format(connect_method_name)) data = { 'id': str(token.id), @@ -154,7 +158,7 @@ class RDPFileClientProtocolURLMixin: 'file': {} } - if connect_method == NativeClient.mstsc: + if connect_method_name == NativeClient.mstsc: filename, content = self.get_rdp_file_info(token) data.update({ 'file': { @@ -163,8 +167,11 @@ class RDPFileClientProtocolURLMixin: } }) else: - endpoint = self.get_smart_endpoint(protocol=token.endpoint_protocol, asset=token.asset) - cmd = NativeClient.get_launch_command(connect_method, token, endpoint) + endpoint = self.get_smart_endpoint( + protocol=connect_method_dict['endpoint_protocol'], + asset=token.asset + ) + cmd = NativeClient.get_launch_command(connect_method_name, token, endpoint) data.update({'command': cmd}) return data diff --git a/apps/authentication/migrations/0019_remove_connectiontoken_endpoint_protocol.py b/apps/authentication/migrations/0019_remove_connectiontoken_endpoint_protocol.py new file mode 100644 index 000000000..ef7f401bd --- /dev/null +++ b/apps/authentication/migrations/0019_remove_connectiontoken_endpoint_protocol.py @@ -0,0 +1,17 @@ +# Generated by Django 3.2.14 on 2022-11-29 13:27 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentication', '0018_connectiontoken_endpoint_protocol'), + ] + + operations = [ + migrations.RemoveField( + model_name='connectiontoken', + name='endpoint_protocol', + ), + ] diff --git a/apps/authentication/models/connection_token.py b/apps/authentication/models/connection_token.py index 7f4e7f42b..5505f81a3 100644 --- a/apps/authentication/models/connection_token.py +++ b/apps/authentication/models/connection_token.py @@ -35,9 +35,6 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel): choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol") ) connect_method = models.CharField(max_length=32, verbose_name=_("Connect method")) - endpoint_protocol = models.CharField( - choices=Protocol.choices, max_length=16, verbose_name=_("Endpoint protocol") - ) user_display = models.CharField(max_length=128, default='', verbose_name=_("User display")) asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display")) date_expired = models.DateTimeField( diff --git a/apps/authentication/serializers/connection_token.py b/apps/authentication/serializers/connection_token.py index e851ceadf..16ef7dc1b 100644 --- a/apps/authentication/serializers/connection_token.py +++ b/apps/authentication/serializers/connection_token.py @@ -23,7 +23,7 @@ class ConnectionTokenSerializer(OrgResourceModelSerializerMixin): fields_small = fields_mini + [ 'user', 'asset', 'account_name', 'input_username', 'input_secret', - 'connect_method', 'endpoint_protocol', 'protocol', + 'connect_method', 'protocol', 'actions', 'date_expired', 'date_created', 'date_updated', 'created_by', 'updated_by', 'org_id', 'org_name', diff --git a/apps/terminal/const.py b/apps/terminal/const.py index 40c89ff24..dbfcc0dec 100644 --- a/apps/terminal/const.py +++ b/apps/terminal/const.py @@ -204,6 +204,15 @@ class TerminalType(TextChoices): } return protocols + @classmethod + def get_connect_method(cls, name, protocol, os): + methods = cls.get_protocols_connect_methods(os) + protocol_methods = methods.get(protocol, []) + for method in protocol_methods: + if method['value'] == name: + return method + return None + @classmethod def get_protocols_connect_methods(cls, os): methods = defaultdict(list)