Merge pull request #10644 from jumpserver/pr@dev@perf_acls_connect_methods

perf: 优化 connect method acls 和登录 acls
pull/10652/head
老广 2023-06-08 14:52:10 +08:00 committed by GitHub
commit d2f1309900
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 309 additions and 143 deletions

View File

@ -1,4 +1,5 @@
from .command_acl import *
from .connect_method import *
from .login_acl import *
from .login_asset_acl import *
from .login_asset_check import *

View File

@ -1,9 +1,8 @@
from rest_framework.decorators import action
from rest_framework.response import Response
from common.drf.filters import BaseFilterSet
from orgs.mixins.api import OrgBulkModelViewSet
from .common import ACLFiltersetMixin
from .common import ACLUserAssetFilterMixin
from .. import models, serializers
__all__ = ['CommandFilterACLViewSet', 'CommandGroupViewSet']
@ -16,10 +15,10 @@ class CommandGroupViewSet(OrgBulkModelViewSet):
serializer_class = serializers.CommandGroupSerializer
class CommandACLFilter(ACLFiltersetMixin, BaseFilterSet):
class CommandACLFilter(ACLUserAssetFilterMixin):
class Meta:
model = models.CommandFilterACL
fields = ['name', 'users', 'assets']
fields = ['name', ]
class CommandFilterACLViewSet(OrgBulkModelViewSet):

View File

@ -1,12 +1,12 @@
from django.db.models import Q
from django_filters import rest_framework as drf_filters
from common.drf.filters import BaseFilterSet
from common.utils import is_uuid
class ACLFiltersetMixin(BaseFilterSet):
class ACLUserFilterMixin(BaseFilterSet):
users = drf_filters.CharFilter(method='filter_user')
assets = drf_filters.CharFilter(method='filter_asset')
@staticmethod
def filter_user(queryset, name, value):
@ -16,12 +16,17 @@ class ACLFiltersetMixin(BaseFilterSet):
if is_uuid(value):
user = User.objects.filter(id=value).first()
else:
user = User.objects.filter(name=value).first()
q = Q(name=value) | Q(username=value)
user = User.objects.filter(q).first()
if not user:
return queryset.none()
q = queryset.model.users.get_filter_q(user)
return queryset.filter(q).distinct()
class ACLUserAssetFilterMixin(ACLUserFilterMixin):
assets = drf_filters.CharFilter(method='filter_asset')
@staticmethod
def filter_asset(queryset, name, value):
from assets.models import Asset
@ -31,7 +36,8 @@ class ACLFiltersetMixin(BaseFilterSet):
if is_uuid(value):
asset = Asset.objects.filter(id=value).first()
else:
asset = Asset.objects.filter(name=value).first()
q = Q(name=value) | Q(address=value)
asset = Asset.objects.filter(q).first()
if not asset:
return queryset.none()

View File

@ -0,0 +1,23 @@
from django_filters import rest_framework as drf_filters
from common.api import JMSBulkModelViewSet
from .common import ACLUserFilterMixin
from .. import serializers
from ..models import ConnectMethodACL
__all__ = ['ConnectMethodACLViewSet']
class ConnectMethodFilter(ACLUserFilterMixin):
methods = drf_filters.CharFilter(field_name="methods__contains", lookup_expr='exact')
class Meta:
model = ConnectMethodACL
fields = ['name', ]
class ConnectMethodACLViewSet(JMSBulkModelViewSet):
queryset = ConnectMethodACL.objects.all()
filterset_class = ConnectMethodFilter
search_fields = ('name',)
serializer_class = serializers.ConnectMethodACLSerializer

View File

@ -1,13 +1,19 @@
from common.api import JMSBulkModelViewSet
from .common import ACLUserFilterMixin
from .. import serializers
from ..filters import LoginAclFilter
from ..models import LoginACL
__all__ = ['LoginACLViewSet']
class LoginACLFilter(ACLUserFilterMixin):
class Meta:
model = LoginACL
fields = ('name', 'action')
class LoginACLViewSet(JMSBulkModelViewSet):
queryset = LoginACL.objects.all()
filterset_class = LoginAclFilter
filterset_class = LoginACLFilter
search_fields = ('name',)
serializer_class = serializers.LoginACLSerializer

View File

@ -1,19 +1,18 @@
from common.drf.filters import BaseFilterSet
from orgs.mixins.api import OrgBulkModelViewSet
from .common import ACLFiltersetMixin
from .common import ACLUserAssetFilterMixin
from .. import models, serializers
__all__ = ['LoginAssetACLViewSet']
class CommandACLFilter(ACLFiltersetMixin, BaseFilterSet):
class LoginAssetACLFilter(ACLUserAssetFilterMixin):
class Meta:
model = models.LoginAssetACL
fields = ['name', 'users', 'assets']
fields = ['name', ]
class LoginAssetACLViewSet(OrgBulkModelViewSet):
model = models.LoginAssetACL
filterset_class = CommandACLFilter
filterset_class = LoginAssetACLFilter
search_fields = ['name']
serializer_class = serializers.LoginAssetACLSerializer

View File

@ -1,15 +0,0 @@
from django_filters import rest_framework as filters
from common.drf.filters import BaseFilterSet
from acls.models import LoginACL
class LoginAclFilter(BaseFilterSet):
user = filters.UUIDFilter(field_name='user_id')
user_display = filters.CharFilter(field_name='user__name')
class Meta:
model = LoginACL
fields = (
'name', 'user', 'user_display', 'action'
)

View File

@ -0,0 +1,46 @@
# Generated by Django 3.2.17 on 2023-06-06 06:23
import uuid
import django.core.validators
from django.conf import settings
from django.db import migrations, models
import common.db.fields
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('acls', '0014_loginassetacl_rules'),
]
operations = [
migrations.CreateModel(
name='ConnectMethodACL',
fields=[
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('name', models.CharField(max_length=128, unique=True, verbose_name='Name')),
('priority', models.IntegerField(default=50, help_text='1-100, the lower the value will be match first',
validators=[django.core.validators.MinValueValidator(1),
django.core.validators.MaxValueValidator(100)],
verbose_name='Priority')),
('action', models.CharField(default='reject', max_length=64, verbose_name='Action')),
('is_active', models.BooleanField(default=True, verbose_name='Active')),
('users', common.db.fields.JSONManyToManyField(default=dict, to='users.User', verbose_name='Users')),
('connect_methods', models.JSONField(default=list, verbose_name='Connect methods')),
(
'reviewers',
models.ManyToManyField(blank=True, to=settings.AUTH_USER_MODEL, verbose_name='Reviewers')),
],
options={
'ordering': ('priority', 'date_updated', 'name'),
'abstract': False,
},
),
]

View File

@ -0,0 +1,37 @@
# Generated by Django 3.2.17 on 2023-06-06 10:57
from django.db import migrations, models
import common.db.fields
def migrate_users_login_acls(apps, schema_editor):
login_acl_model = apps.get_model('acls', 'LoginACL')
for login_acl in login_acl_model.objects.all():
login_acl.users = {
"type": "ids", "ids": [str(login_acl.user_id)]
}
login_acl.save()
class Migration(migrations.Migration):
dependencies = [
('acls', '0015_connectmethodacl'),
]
operations = [
migrations.AddField(
model_name='loginacl',
name='users',
field=common.db.fields.JSONManyToManyField(default=dict, to='users.User', verbose_name='Users'),
),
migrations.RemoveField(
model_name='loginacl',
name='user',
),
migrations.AlterField(
model_name='loginacl',
name='name',
field=models.CharField(max_length=128, unique=True, verbose_name='Name'),
),
]

View File

@ -1,3 +1,4 @@
from .command_acl import *
from .connect_method import *
from .login_acl import *
from .login_asset_acl import *
from .command_acl import *

View File

@ -9,7 +9,7 @@ from common.utils.time_period import contains_time_period
from orgs.mixins.models import OrgModelMixin
__all__ = [
'BaseACL', 'UserAssetAccountBaseACL',
'BaseACL', 'UserBaseACL', 'UserAssetAccountBaseACL',
]
@ -34,7 +34,7 @@ class BaseACLQuerySet(models.QuerySet):
class BaseACL(JMSBaseModel):
name = models.CharField(max_length=128, verbose_name=_('Name'))
name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
priority = models.IntegerField(
default=50, verbose_name=_("Priority"),
help_text=_("1-100, the lower the value will be match first"),
@ -79,13 +79,27 @@ class BaseACL(JMSBaseModel):
return None
class UserAssetAccountBaseACL(BaseACL, OrgModelMixin):
class UserBaseACL(BaseACL):
users = JSONManyToManyField('users.User', default=dict, verbose_name=_('Users'))
class Meta:
abstract = True
@classmethod
def get_user_acls(cls, user):
queryset = cls.objects.all()
q = cls.users.get_filter_q(user)
queryset = queryset.filter(q)
return queryset.valid().distinct()
class UserAssetAccountBaseACL(UserBaseACL, OrgModelMixin):
name = models.CharField(max_length=128, verbose_name=_('Name'))
assets = JSONManyToManyField('assets.Asset', default=dict, verbose_name=_('Assets'))
accounts = models.JSONField(default=list, verbose_name=_("Accounts"))
class Meta(BaseACL.Meta):
unique_together = ('name', 'org_id')
class Meta(UserBaseACL.Meta):
unique_together = [('name', 'org_id')]
abstract = True
@classmethod

View File

@ -0,0 +1,10 @@
from django.db import models
from django.utils.translation import gettext_lazy as _
from .base import UserBaseACL
__all__ = ['ConnectMethodACL']
class ConnectMethodACL(UserBaseACL):
connect_methods = models.JSONField(default=list, verbose_name=_('Connect methods'))

View File

@ -3,18 +3,14 @@ from django.utils.translation import ugettext_lazy as _
from common.utils import get_request_ip, get_ip_city
from common.utils.timezone import local_now_display
from .base import BaseACL
from .base import UserBaseACL
class LoginACL(BaseACL):
user = models.ForeignKey(
'users.User', on_delete=models.CASCADE,
related_name='login_acls', verbose_name=_('User')
)
class LoginACL(UserBaseACL):
# 规则, ip_group, time_period
rules = models.JSONField(default=dict, verbose_name=_('Rule'))
class Meta(BaseACL.Meta):
class Meta(UserBaseACL.Meta):
verbose_name = _('Login acl')
abstract = False
@ -28,10 +24,6 @@ class LoginACL(BaseACL):
def filter_acl(cls, user):
return user.login_acls.all().valid().distinct()
@classmethod
def get_user_acls(cls, user):
return cls.filter_acl(user)
def create_confirm_ticket(self, request):
from tickets import const
from tickets.models import ApplyLoginTicket

View File

@ -1,4 +1,5 @@
from .command_acl import *
from .connect_method import *
from .login_acl import *
from .login_asset_acl import *
from .login_asset_check import *
from .command_acl import *

View File

@ -1,11 +1,10 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from acls.models.base import ActionChoices
from acls.models.base import ActionChoices, BaseACL
from common.serializers.fields import JSONManyToManyField, LabeledChoiceField
from jumpserver.utils import has_valid_xpack_license
from common.serializers.fields import JSONManyToManyField, ObjectRelatedField, LabeledChoiceField
from orgs.models import Organization
from users.models import User
common_help_text = _(
"With * indicating a match all. "
@ -71,25 +70,16 @@ class ActionAclSerializer(serializers.Serializer):
action._choices = choices
class BaseUserAssetAccountACLSerializerMixin(ActionAclSerializer, serializers.Serializer):
users = JSONManyToManyField(label=_('User'))
assets = JSONManyToManyField(label=_('Asset'))
accounts = serializers.ListField(label=_('Account'))
reviewers = ObjectRelatedField(
queryset=User.objects, many=True, required=False, label=_('Reviewers')
)
reviewers_amount = serializers.IntegerField(
read_only=True, source="reviewers.count", label=_('Reviewers amount')
)
class BaserACLSerializer(ActionAclSerializer, serializers.Serializer):
class Meta:
model = BaseACL
fields_mini = ["id", "name"]
fields_small = fields_mini + [
"users", "accounts", "assets", "is_active",
"date_created", "date_updated", "priority",
"action", "comment", "created_by", "org_id",
"is_active", "priority", "action",
"date_created", "date_updated",
"comment", "created_by", "org_id",
]
fields_m2m = ["reviewers", "reviewers_amount"]
fields_m2m = ["reviewers", ]
fields = fields_small + fields_m2m
extra_kwargs = {
"priority": {"default": 50},
@ -115,3 +105,18 @@ class BaseUserAssetAccountACLSerializerMixin(ActionAclSerializer, serializers.Se
)
raise serializers.ValidationError(error)
return valid_reviewers
class BaserUserACLSerializer(BaserACLSerializer):
users = JSONManyToManyField(label=_('User'))
class Meta(BaserACLSerializer.Meta):
fields = BaserACLSerializer.Meta.fields + ['users']
class BaseUserAssetAccountACLSerializer(BaserUserACLSerializer):
assets = JSONManyToManyField(label=_('Asset'))
accounts = serializers.ListField(label=_('Account'))
class Meta(BaserUserACLSerializer.Meta):
fields = BaserUserACLSerializer.Meta.fields + ['assets', 'accounts']

View File

@ -1,13 +1,13 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from terminal.models import Session
from acls.models import CommandGroup, CommandFilterACL
from common.utils import lazyproperty, get_object_or_none
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from orgs.utils import tmp_to_root_org
from common.utils import lazyproperty, get_object_or_none
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaseUserAssetAccountACLSerializerMixin as BaseSerializer
from orgs.utils import tmp_to_root_org
from terminal.models import Session
from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
__all__ = ["CommandFilterACLSerializer", "CommandGroupSerializer", "CommandReviewSerializer"]

View File

@ -0,0 +1,23 @@
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
from ..models import ConnectMethodACL
__all__ = ["ConnectMethodACLSerializer"]
class ConnectMethodACLSerializer(BaseSerializer, BulkOrgResourceModelSerializer):
class Meta(BaseSerializer.Meta):
model = ConnectMethodACL
fields = [
i for i in BaseSerializer.Meta.fields + ['connect_methods']
if i not in ['assets', 'accounts']
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
field_action = self.fields.get('action')
if not field_action:
return
# 仅支持拒绝
for k in ['review', 'accept']:
field_action._choices.pop(k, None)

View File

@ -1,47 +1,22 @@
from django.utils.translation import ugettext as _
from rest_framework import serializers
from common.serializers import BulkModelSerializer, MethodSerializer
from common.serializers.fields import ObjectRelatedField
from users.models import User
from .base import ActionAclSerializer
from common.serializers import MethodSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaserUserACLSerializer
from .rules import RuleSerializer
from ..models import LoginACL
__all__ = [
"LoginACLSerializer",
]
__all__ = ["LoginACLSerializer"]
common_help_text = _(
"With * indicating a match all. "
)
common_help_text = _("With * indicating a match all. ")
class LoginACLSerializer(ActionAclSerializer, BulkModelSerializer):
user = ObjectRelatedField(queryset=User.objects, label=_("User"))
reviewers = ObjectRelatedField(
queryset=User.objects, label=_("Reviewers"), many=True, required=False
)
reviewers_amount = serializers.IntegerField(
read_only=True, source="reviewers.count", label=_("Reviewers amount")
)
class LoginACLSerializer(BaserUserACLSerializer, BulkOrgResourceModelSerializer):
rules = MethodSerializer(label=_('Rule'))
class Meta:
class Meta(BaserUserACLSerializer.Meta):
model = LoginACL
fields_mini = ["id", "name"]
fields_small = fields_mini + [
"priority", "user", "rules", "action",
"is_active", "date_created", "date_updated",
"comment", "created_by",
]
fields_fk = ["user"]
fields_m2m = ["reviewers", "reviewers_amount"]
fields = fields_small + fields_fk + fields_m2m
extra_kwargs = {
"priority": {"default": 50},
"is_active": {"default": True},
}
fields = BaserUserACLSerializer.Meta.fields + ['rules', ]
def get_rules_serializer(self):
return RuleSerializer()

View File

@ -2,7 +2,7 @@ from django.utils.translation import gettext_lazy as _
from common.serializers import MethodSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaseUserAssetAccountACLSerializerMixin as BaseSerializer
from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
from .rules import RuleSerializer
from ..models import LoginAssetACL

View File

@ -10,6 +10,7 @@ router.register(r'login-acls', api.LoginACLViewSet, 'login-acl')
router.register(r'login-asset-acls', api.LoginAssetACLViewSet, 'login-asset-acl')
router.register(r'command-filter-acls', api.CommandFilterACLViewSet, 'command-filter-acl')
router.register(r'command-groups', api.CommandGroupViewSet, 'command-group')
router.register(r'connect-method-acls', api.ConnectMethodACLViewSet, 'connect-method-acl')
urlpatterns = [
path('login-asset/check/', api.LoginAssetCheckAPI.as_view(), name='login-asset-check'),

View File

@ -1,21 +1,21 @@
# -*- coding: utf-8 -*-
#
import re
from django.templatetags.static import static
from collections import OrderedDict
from itertools import chain
import logging
import datetime
import uuid
from functools import wraps
import time
import ipaddress
import psutil
import platform
import logging
import os
import platform
import re
import socket
import time
import uuid
from collections import OrderedDict
from functools import wraps
from itertools import chain
import psutil
from django.conf import settings
from django.templatetags.static import static
UUID_PATTERN = re.compile(r'\w{8}(-\w{4}){3}-\w{12}')
ipip_db = None
@ -76,6 +76,7 @@ def setattr_bulk(seq, key, value):
def set_attr(obj):
setattr(obj, key, value)
return obj
return map(set_attr, seq)
@ -97,12 +98,12 @@ def capacity_convert(size, expect='auto', rate=1000):
rate_mapping = (
('K', rate),
('KB', rate),
('M', rate**2),
('MB', rate**2),
('G', rate**3),
('GB', rate**3),
('T', rate**4),
('TB', rate**4),
('M', rate ** 2),
('MB', rate ** 2),
('G', rate ** 3),
('GB', rate ** 3),
('T', rate ** 4),
('TB', rate ** 4),
)
rate_mapping = OrderedDict(rate_mapping)
@ -117,7 +118,7 @@ def capacity_convert(size, expect='auto', rate=1000):
if expect == 'auto':
for unit, rate_ in rate_mapping.items():
if rate > std_size/rate_ >= 1 or unit == "T":
if rate > std_size / rate_ >= 1 or unit == "T":
expect = unit
break
@ -195,6 +196,7 @@ def with_cache(func):
res = func(*args, **kwargs)
cache[key] = res
return res
return wrapper
@ -216,6 +218,7 @@ def timeit(func):
msg = "End call {}, using: {:.1f}ms".format(name, using)
logger.debug(msg)
return result
return wrapper
@ -310,7 +313,7 @@ class Time:
def print(self):
last, *timestamps = self._timestamps
for timestamp, msg in zip(timestamps, self._msgs):
logger.debug(f'TIME_IT: {msg} {timestamp-last}')
logger.debug(f'TIME_IT: {msg} {timestamp - last}')
last = timestamp
@ -367,7 +370,7 @@ def pretty_string(data, max_length=128, ellipsis_str='...'):
def group_by_count(it, count):
return [it[i:i+count] for i in range(0, len(it), count)]
return [it[i:i + count] for i in range(0, len(it), count)]
def test_ip_connectivity(host, port, timeout=0.5):
@ -395,3 +398,17 @@ def static_or_direct(logo_path):
def make_dirs(name, mode=0o755, exist_ok=False):
""" 默认权限设置为 0o755 """
return os.makedirs(name, mode=mode, exist_ok=exist_ok)
def distinct(seq, key=None):
if key is None:
# 如果未提供关键字参数,则默认使用元素本身作为比较键
key = lambda x: x
seen = set()
result = []
for item in seq:
k = key(item)
if k not in seen:
seen.add(k)
result.append(item)
return result

View File

@ -148,6 +148,8 @@ only_system_permissions = (
('orgs', 'organization', 'view', 'rootorg'),
('terminal', 'applet', '*', '*'),
('terminal', 'applethost', '*', '*'),
('acls', 'loginacl', '*', '*'),
('acls', 'connectmethodacl', '*', '*')
)
only_org_permissions = (

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
#
import itertools
from rest_framework import generics
from rest_framework.views import Response
from common.permissions import IsValidUser
from common.utils import get_request_os
from common.utils import get_request_os, is_true, distinct
from terminal import serializers
from terminal.connect_methods import ConnectMethodUtil
@ -16,9 +17,29 @@ class ConnectMethodListApi(generics.ListAPIView):
serializer_class = serializers.ConnectMethodSerializer
permission_classes = [IsValidUser]
def filter_user_connect_methods(self, d):
from acls.models import ConnectMethodACL
# 这里要根据用户来了,受 acl 影响
acls = ConnectMethodACL.get_user_acls(self.request.user)
disabled_connect_methods = acls.values_list('connect_methods', flat=True)
disabled_connect_methods = set(itertools.chain.from_iterable(disabled_connect_methods))
new_queryset = {}
for protocol, methods in d.items():
new_queryset[protocol] = [x for x in methods if x['value'] not in disabled_connect_methods]
return new_queryset
def get_queryset(self):
os = get_request_os(self.request)
return ConnectMethodUtil.get_filtered_protocols_connect_methods(os)
os = self.request.query_params.get('os') or get_request_os(self.request)
queryset = ConnectMethodUtil.get_filtered_protocols_connect_methods(os)
flat = self.request.query_params.get('flat')
# 先这么处理, 这里不用过滤包含的事所有
if is_true(flat):
queryset = itertools.chain.from_iterable(queryset.values())
queryset = distinct(queryset, key=lambda x: x['value'])
else:
queryset = self.filter_queryset(queryset)
return queryset
def list(self, request, *args, **kwargs):
queryset = self.get_queryset()

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
#
import itertools
from collections import defaultdict
from django.conf import settings
@ -51,11 +52,7 @@ class NativeClient(TextChoices):
xshell = 'xshell', 'Xshell'
# Magnus
mysql = 'db_client_mysql', _('DB Client')
psql = 'db_client_psql', _('DB Client')
sqlplus = 'db_client_sqlplus', _('DB Client')
redis = 'db_client_redis', _('DB Client')
mongodb = 'db_client_mongodb', _('DB Client')
db_client = 'db_client', _('DB Client')
# Razor
mstsc = 'mstsc', 'Remote Desktop'
@ -70,12 +67,12 @@ class NativeClient(TextChoices):
'windows': [cls.putty],
},
Protocol.rdp: [cls.mstsc],
Protocol.mysql: [cls.mysql],
Protocol.mariadb: [cls.mysql],
Protocol.oracle: [cls.sqlplus],
Protocol.postgresql: [cls.psql],
Protocol.redis: [cls.redis],
Protocol.mongodb: [cls.mongodb],
Protocol.mysql: [cls.db_client],
Protocol.mariadb: [cls.db_client],
Protocol.oracle: [cls.db_client],
Protocol.postgresql: [cls.db_client],
Protocol.redis: [cls.db_client],
Protocol.mongodb: [cls.db_client],
}
return clients
@ -83,7 +80,10 @@ class NativeClient(TextChoices):
def get_target_protocol(cls, name, os):
for protocol, clients in cls.get_native_clients().items():
if isinstance(clients, dict):
clients = clients.get(os) or clients.get('default')
if os == 'all':
clients = list(itertools.chain(*clients.values()))
else:
clients = clients.get(os) or clients.get('default')
if name in clients:
return protocol
return None
@ -99,7 +99,10 @@ class NativeClient(TextChoices):
for protocol, _clients in clients_map.items():
if isinstance(_clients, dict):
_clients = _clients.get(os, _clients['default'])
if os == 'all':
_clients = list(itertools.chain(*_clients.values()))
else:
_clients = _clients.get(os, _clients['default'])
for client in _clients:
if not settings.XPACK_ENABLED and client in cls.xpack_methods():
continue
@ -245,11 +248,10 @@ class ConnectMethodUtil:
if not getattr(settings, 'TERMINAL_KOKO_SSH_ENABLED'):
protocol = Protocol.ssh
methods[protocol] = [m for m in methods[protocol] if m['type'] != 'native']
return methods
@classmethod
def get_protocols_connect_methods(cls, os):
def get_protocols_connect_methods(cls, os='windows'):
if cls._all_methods.get('os'):
return cls._all_methods['os']
@ -264,7 +266,7 @@ class ConnectMethodUtil:
for protocol in support:
# Web 方式
methods[protocol.value].extend([
methods[str(protocol)].extend([
{
'component': component.value,
'type': 'web',
@ -286,7 +288,7 @@ class ConnectMethodUtil:
if component == TerminalType.koko and protocol.value != Protocol.ssh:
# koko 仅支持 ssh 的 native 方式,其他数据库的 native 方式不提供
continue
methods[protocol.value].extend([
methods[str(protocol)].extend([
{
'component': component.value,
'type': 'native',