perf: 修改 Connect token 数据结构

pull/9133/head
ibuler 2022-11-29 14:42:04 +08:00
parent e4edf3be02
commit 0981cd1ed1
11 changed files with 129 additions and 89 deletions

View File

@ -16,10 +16,11 @@ from rest_framework.serializers import ValidationError
from common.drf.api import JMSModelViewSet
from common.http import is_true
from common.utils import random_string
from common.utils.django import get_request_os
from orgs.mixins.api import RootOrgViewMixin
from perms.models import ActionChoices
from terminal.models import EndpointRule
from terminal.const import NativeClient
from terminal.models import EndpointRule
from ..models import ConnectionToken
from ..serializers import (
ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
@ -130,42 +131,32 @@ class RDPFileClientProtocolURLMixin:
return true_value if is_true(os.getenv(env_key, env_default)) else false_value
def get_client_protocol_data(self, token: ConnectionToken):
username = token.user.username
rdp_config = ssh_token = ''
connect_method = token.connect_method
_os = get_request_os(self.request)
if connect_method == NativeClient.ssh:
filename, ssh_token = self.get_ssh_token(token)
elif connect_method == NativeClient.mstsc:
filename, rdp_config = self.get_rdp_file_info(token)
else:
raise ValueError('Protocol not support: {}'.format(connect_method))
connect_method = getattr(NativeClient, token.connect_method, None)
if connect_method is None:
raise ValueError('Connect method not support: {}'.format(token.connect_method))
return {
"filename": filename,
"protocol": token.protocol,
"username": username,
"token": ssh_token,
"config": rdp_config
}
def get_ssh_token(self, token: ConnectionToken):
if token.asset:
name = token.asset.name
else:
name = '*'
prefix_name = f'{token.user.username}-{name}'
filename = self.get_connect_filename(prefix_name)
endpoint = self.get_smart_endpoint(protocol='ssh', asset=token.asset)
data = {
'ip': endpoint.host,
'port': str(endpoint.ssh_port),
'username': 'JMS-{}'.format(str(token.id)),
'password': token.value
'id': str(token.id),
'value': token.value,
'cmd': '',
'file': {}
}
token = json.dumps(data)
return filename, token
if connect_method == NativeClient.mstsc:
filename, content = self.get_rdp_file_info(token)
data.update({
'file': {
'filename': filename,
'content': content,
}
})
else:
endpoint = self.get_smart_endpoint(protocol=token.endpoint_protocol, asset=token.asset)
cmd = NativeClient.get_launch_command(connect_method, token, endpoint)
data.update({'cmd': cmd})
return data
def get_smart_endpoint(self, protocol, asset=None):
target_ip = asset.get_target_ip() if asset else ''
@ -223,6 +214,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
'get_secret_detail': ConnectionTokenSecretSerializer,
}
rbac_perms = {
'list': 'authentication.view_connectiontoken',
'retrieve': 'authentication.view_connectiontoken',
'create': 'authentication.add_connectiontoken',
'expire': 'authentication.add_connectiontoken',
@ -252,9 +244,9 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
return Response(serializer.data, status=status.HTTP_200_OK)
def get_queryset(self):
queryset = ConnectionToken.objects\
.filter(user=self.request.user)\
.filter(date_expired__lt=timezone.now())
queryset = ConnectionToken.objects \
.filter(user=self.request.user) \
.filter(date_expired__gt=timezone.now())
return queryset
def get_user(self, serializer):

View File

@ -1,11 +1,11 @@
# Generated by Django 3.2.14 on 2022-11-25 14:40
import common.db.fields
from django.db import migrations, models
import common.db.fields
class Migration(migrations.Migration):
dependencies = [
('authentication', '0015_alter_connectiontoken_login'),
]
@ -36,4 +36,15 @@ class Migration(migrations.Migration):
name='value',
field=models.CharField(default='', max_length=64, verbose_name='Value'),
),
migrations.AddField(
model_name='connectiontoken',
name='input_secret',
field=common.db.fields.EncryptCharField(blank=True, default='', max_length=128,
verbose_name='Input Secret'),
),
migrations.AlterField(
model_name='connectiontoken',
name='input_username',
field=models.CharField(blank=True, default='', max_length=128, verbose_name='Input Username'),
),
]

View File

@ -1,11 +1,9 @@
# Generated by Django 3.2.14 on 2022-11-28 10:39
import common.db.fields
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authentication', '0016_auto_20221125_2240'),
]
@ -17,15 +15,4 @@ class Migration(migrations.Migration):
field=models.CharField(default='web_ui', max_length=32, verbose_name='Connect method'),
preserve_default=False,
),
migrations.AddField(
model_name='connectiontoken',
name='input_secret',
field=common.db.fields.EncryptCharField(blank=True, default='', max_length=128,
verbose_name='Input Secret'),
),
migrations.AlterField(
model_name='connectiontoken',
name='input_username',
field=models.CharField(blank=True, default='', max_length=128, verbose_name='Input Username'),
),
]

View File

@ -0,0 +1,19 @@
# Generated by Django 3.2.14 on 2022-11-29 04:49
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authentication', '0017_auto_20221128_1839'),
]
operations = [
migrations.AddField(
model_name='connectiontoken',
name='endpoint_protocol',
field=models.CharField(choices=[('ssh', 'SSH'), ('rdp', 'RDP'), ('telnet', 'Telnet'), ('vnc', 'VNC'), ('mysql', 'MySQL'), ('mariadb', 'MariaDB'), ('oracle', 'Oracle'), ('postgresql', 'PostgreSQL'), ('sqlserver', 'SQLServer'), ('redis', 'Redis'), ('mongodb', 'MongoDB'), ('k8s', 'K8S'), ('http', 'HTTP'), ('None', ' Settings')], default='', max_length=16, verbose_name='Endpoint protocol'),
preserve_default=False,
),
]

View File

@ -35,6 +35,9 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel):
choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol")
)
connect_method = models.CharField(max_length=32, verbose_name=_("Connect method"))
endpoint_protocol = models.CharField(
choices=Protocol.choices, max_length=16, verbose_name=_("Endpoint protocol")
)
user_display = models.CharField(max_length=128, default='', verbose_name=_("User display"))
asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display"))
date_expired = models.DateTimeField(

View File

@ -1,7 +1,7 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from assets.models import Asset, Domain, CommandFilterRule, Account, Platform
from assets.models import Asset, CommandFilterRule, Account, Platform
from assets.serializers import PlatformSerializer, AssetProtocolsSerializer
from authentication.models import ConnectionToken
from orgs.mixins.serializers import OrgResourceModelSerializerMixin
@ -21,21 +21,19 @@ class ConnectionTokenSerializer(OrgResourceModelSerializerMixin):
model = ConnectionToken
fields_mini = ['id', 'value']
fields_small = fields_mini + [
'protocol', 'account_name',
'user', 'asset', 'account_name',
'input_username', 'input_secret',
'connect_method', 'endpoint_protocol', 'protocol',
'actions', 'date_expired', 'date_created',
'date_updated', 'created_by',
'updated_by', 'org_id', 'org_name',
]
fields_fk = [
'user', 'asset',
]
read_only_fields = [
# 普通 Token 不支持指定 user
'user', 'expire_time',
'user_display', 'asset_display',
]
fields = fields_small + fields_fk + read_only_fields
fields = fields_small + read_only_fields
extra_kwargs = {
'value': {'read_only': True},
}

View File

@ -2,11 +2,11 @@
#
import re
from django.shortcuts import reverse as dj_reverse
from django.conf import settings
from django.utils import timezone
from django.db import models
from django.db.models.signals import post_save, pre_save
from django.shortcuts import reverse as dj_reverse
from django.utils import timezone
UUID_PATTERN = re.compile(r'[0-9a-zA-Z\-]{36}')
@ -80,3 +80,18 @@ def bulk_create_with_signal(cls: models.Model, items, **kwargs):
for i in items:
post_save.send(sender=cls, instance=i, created=True)
return result
def get_request_os(request):
"""获取请求的操作系统"""
agent = request.META.get('HTTP_USER_AGENT', '').lower()
if agent is None:
return 'unknown'
if 'windows' in agent.lower():
return 'windows'
if 'mac' in agent.lower():
return 'mac'
if 'linux' in agent.lower():
return 'linux'
return 'unknown'

View File

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
#
import time
from email.utils import formatdate
import calendar
import threading
import time
from email.utils import formatdate
_STRPTIME_LOCK = threading.Lock()
@ -35,3 +35,6 @@ def http_to_unixtime(time_string):
def iso8601_to_unixtime(time_string):
"""把ISO8601时间字符串形如2012-02-24T06:07:48.000Z转换为UNIX时间精确到秒。"""
return to_unixtime(time_string, _ISO8601_FORMAT)

View File

@ -12,6 +12,7 @@ from common.drf.api import JMSBulkModelViewSet
from common.exceptions import JMSException
from common.permissions import IsValidUser
from common.permissions import WithBootstrapToken
from common.utils import get_request_os
from terminal import serializers
from terminal.const import TerminalType
from terminal.models import Terminal
@ -77,13 +78,7 @@ class ConnectMethodListApi(generics.ListAPIView):
permission_classes = [IsValidUser]
def get_queryset(self):
user_agent = self.request.META['HTTP_USER_AGENT'].lower()
if 'macintosh' in user_agent:
os = 'macos'
elif 'windows' in user_agent:
os = 'windows'
else:
os = 'linux'
os = get_request_os(self.request)
return TerminalType.get_protocols_connect_methods(os)
def list(self, request, *args, **kwargs):

View File

@ -56,7 +56,11 @@ class NativeClient(TextChoices):
xshell = 'xshell', 'Xshell'
# Magnus
db_client = 'db_client', _('DB Client')
mysql = 'db_client_mysql', _('DB Client')
psql = 'db_client_psql', _('DB Client')
sqlplus = 'db_client_sqlplus', _('DB Client')
redis = 'db_client_redis', _('DB Client')
mongodb = 'db_client_mongodb', _('DB Client')
# Razor
mstsc = 'mstsc', 'Remote Desktop'
@ -69,14 +73,23 @@ class NativeClient(TextChoices):
'windows': [cls.putty],
},
Protocol.rdp: [cls.mstsc],
Protocol.mysql: [cls.db_client],
Protocol.oracle: [cls.db_client],
Protocol.postgresql: [cls.db_client],
Protocol.redis: [cls.db_client],
Protocol.mongodb: [cls.db_client],
Protocol.mysql: [cls.mysql],
Protocol.oracle: [cls.sqlplus],
Protocol.postgresql: [cls.psql],
Protocol.redis: [cls.redis],
Protocol.mongodb: [cls.mongodb],
}
return clients
@classmethod
def get_target_protocol(cls, name, os):
for protocol, clients in cls.get_native_clients().items():
if isinstance(clients, dict):
clients = clients.get(os) or clients.get('default')
if name in clients:
return protocol
return None
@classmethod
def get_methods(cls, os='windows'):
clients_map = cls.get_native_clients()
@ -94,23 +107,18 @@ class NativeClient(TextChoices):
return methods
@classmethod
def get_launch_command(cls, name, os='windows'):
def get_launch_command(cls, name, token, endpoint, os='windows'):
commands = {
cls.ssh: 'ssh {token.id}@{endpoint.ip} -p {endpoint.port}',
cls.putty: 'putty-ssh {token.id}@{endpoint.ip} -P {endpoint.port}',
cls.xshell: 'xshell -url ssh://{token.id}:{token.value}@{endpoint.ip}:{endpoint.port}',
# 'mysql': 'mysql -h {hostname} -P {port} -u {username} -p',
# 'psql': {
cls.ssh: f'ssh {token.id}@{endpoint.host} -p {endpoint.ssh_port}',
cls.putty: f'putty -ssh {token.id}@{endpoint.host} -P {endpoint.ssh_port}',
cls.xshell: f'xshell -url ssh://{token.id}:{token.value}@{endpoint.host}:{endpoint.ssh_port}',
# cls.mysql: 'mysql -h {hostname} -P {port} -u {username} -p',
# cls.psql: {
# 'default': 'psql -h {hostname} -p {port} -U {username} -W',
# 'windows': 'psql /h {hostname} /p {port} /U {username} -W',
# },
# 'sqlplus': 'sqlplus {username}/{password}@{hostname}:{port}',
# 'redis': 'redis-cli -h {hostname} -p {port} -a {password}',
cls.mstsc: {
'command': "$open_file$",
'file': {
}
},
# cls.sqlplus: 'sqlplus {username}/{password}@{hostname}:{port}',
# cls.redis: 'redis-cli -h {hostname} -p {port} -a {password}',
}
command = commands.get(name)
if isinstance(command, dict):
@ -217,19 +225,26 @@ class TerminalType(TextChoices):
methods[protocol.value].append({
'value': web_protocol.value,
'label': web_protocol.label,
'endpoint_protocol': 'http',
'type': 'web',
'component': component.value,
})
# Native method
methods[protocol.value].extend([
{'component': component.value, 'type': 'native', **method}
{
'component': component.value,
'type': 'native',
'endpoint_protocol': listen_protocol,
**method
}
for method in native_methods[listen_protocol]
])
for protocol, applet_methods in applet_methods.items():
for method in applet_methods:
method['type'] = 'applet'
method['listen'] = 'rdp'
method['component'] = cls.tinker.value
methods[protocol].extend(applet_methods)
return methods

View File

@ -138,4 +138,6 @@ class TerminalRegistrationSerializer(serializers.ModelSerializer):
class ConnectMethodSerializer(serializers.Serializer):
value = serializers.CharField(max_length=128)
label = serializers.CharField(max_length=128)
group = serializers.CharField(max_length=128)
type = serializers.CharField(max_length=128)
listen = serializers.CharField(max_length=128)
component = serializers.CharField(max_length=128)