Merge pull request #10644 from jumpserver/pr@dev@perf_acls_connect_methods

perf: 优化 connect method acls 和登录 acls
pull/10652/head
老广 2023-06-08 14:52:10 +08:00 committed by GitHub
commit d2f1309900
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 309 additions and 143 deletions

View File

@ -1,4 +1,5 @@
from .command_acl import * from .command_acl import *
from .connect_method import *
from .login_acl import * from .login_acl import *
from .login_asset_acl import * from .login_asset_acl import *
from .login_asset_check import * from .login_asset_check import *

View File

@ -1,9 +1,8 @@
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from common.drf.filters import BaseFilterSet
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from .common import ACLFiltersetMixin from .common import ACLUserAssetFilterMixin
from .. import models, serializers from .. import models, serializers
__all__ = ['CommandFilterACLViewSet', 'CommandGroupViewSet'] __all__ = ['CommandFilterACLViewSet', 'CommandGroupViewSet']
@ -16,10 +15,10 @@ class CommandGroupViewSet(OrgBulkModelViewSet):
serializer_class = serializers.CommandGroupSerializer serializer_class = serializers.CommandGroupSerializer
class CommandACLFilter(ACLFiltersetMixin, BaseFilterSet): class CommandACLFilter(ACLUserAssetFilterMixin):
class Meta: class Meta:
model = models.CommandFilterACL model = models.CommandFilterACL
fields = ['name', 'users', 'assets'] fields = ['name', ]
class CommandFilterACLViewSet(OrgBulkModelViewSet): class CommandFilterACLViewSet(OrgBulkModelViewSet):

View File

@ -1,12 +1,12 @@
from django.db.models import Q
from django_filters import rest_framework as drf_filters from django_filters import rest_framework as drf_filters
from common.drf.filters import BaseFilterSet from common.drf.filters import BaseFilterSet
from common.utils import is_uuid from common.utils import is_uuid
class ACLFiltersetMixin(BaseFilterSet): class ACLUserFilterMixin(BaseFilterSet):
users = drf_filters.CharFilter(method='filter_user') users = drf_filters.CharFilter(method='filter_user')
assets = drf_filters.CharFilter(method='filter_asset')
@staticmethod @staticmethod
def filter_user(queryset, name, value): def filter_user(queryset, name, value):
@ -16,12 +16,17 @@ class ACLFiltersetMixin(BaseFilterSet):
if is_uuid(value): if is_uuid(value):
user = User.objects.filter(id=value).first() user = User.objects.filter(id=value).first()
else: else:
user = User.objects.filter(name=value).first() q = Q(name=value) | Q(username=value)
user = User.objects.filter(q).first()
if not user: if not user:
return queryset.none() return queryset.none()
q = queryset.model.users.get_filter_q(user) q = queryset.model.users.get_filter_q(user)
return queryset.filter(q).distinct() return queryset.filter(q).distinct()
class ACLUserAssetFilterMixin(ACLUserFilterMixin):
assets = drf_filters.CharFilter(method='filter_asset')
@staticmethod @staticmethod
def filter_asset(queryset, name, value): def filter_asset(queryset, name, value):
from assets.models import Asset from assets.models import Asset
@ -31,7 +36,8 @@ class ACLFiltersetMixin(BaseFilterSet):
if is_uuid(value): if is_uuid(value):
asset = Asset.objects.filter(id=value).first() asset = Asset.objects.filter(id=value).first()
else: else:
asset = Asset.objects.filter(name=value).first() q = Q(name=value) | Q(address=value)
asset = Asset.objects.filter(q).first()
if not asset: if not asset:
return queryset.none() return queryset.none()

View File

@ -0,0 +1,23 @@
from django_filters import rest_framework as drf_filters
from common.api import JMSBulkModelViewSet
from .common import ACLUserFilterMixin
from .. import serializers
from ..models import ConnectMethodACL
__all__ = ['ConnectMethodACLViewSet']
class ConnectMethodFilter(ACLUserFilterMixin):
methods = drf_filters.CharFilter(field_name="methods__contains", lookup_expr='exact')
class Meta:
model = ConnectMethodACL
fields = ['name', ]
class ConnectMethodACLViewSet(JMSBulkModelViewSet):
queryset = ConnectMethodACL.objects.all()
filterset_class = ConnectMethodFilter
search_fields = ('name',)
serializer_class = serializers.ConnectMethodACLSerializer

View File

@ -1,13 +1,19 @@
from common.api import JMSBulkModelViewSet from common.api import JMSBulkModelViewSet
from .common import ACLUserFilterMixin
from .. import serializers from .. import serializers
from ..filters import LoginAclFilter
from ..models import LoginACL from ..models import LoginACL
__all__ = ['LoginACLViewSet'] __all__ = ['LoginACLViewSet']
class LoginACLFilter(ACLUserFilterMixin):
class Meta:
model = LoginACL
fields = ('name', 'action')
class LoginACLViewSet(JMSBulkModelViewSet): class LoginACLViewSet(JMSBulkModelViewSet):
queryset = LoginACL.objects.all() queryset = LoginACL.objects.all()
filterset_class = LoginAclFilter filterset_class = LoginACLFilter
search_fields = ('name',) search_fields = ('name',)
serializer_class = serializers.LoginACLSerializer serializer_class = serializers.LoginACLSerializer

View File

@ -1,19 +1,18 @@
from common.drf.filters import BaseFilterSet
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from .common import ACLFiltersetMixin from .common import ACLUserAssetFilterMixin
from .. import models, serializers from .. import models, serializers
__all__ = ['LoginAssetACLViewSet'] __all__ = ['LoginAssetACLViewSet']
class CommandACLFilter(ACLFiltersetMixin, BaseFilterSet): class LoginAssetACLFilter(ACLUserAssetFilterMixin):
class Meta: class Meta:
model = models.LoginAssetACL model = models.LoginAssetACL
fields = ['name', 'users', 'assets'] fields = ['name', ]
class LoginAssetACLViewSet(OrgBulkModelViewSet): class LoginAssetACLViewSet(OrgBulkModelViewSet):
model = models.LoginAssetACL model = models.LoginAssetACL
filterset_class = CommandACLFilter filterset_class = LoginAssetACLFilter
search_fields = ['name'] search_fields = ['name']
serializer_class = serializers.LoginAssetACLSerializer serializer_class = serializers.LoginAssetACLSerializer

View File

@ -1,15 +0,0 @@
from django_filters import rest_framework as filters
from common.drf.filters import BaseFilterSet
from acls.models import LoginACL
class LoginAclFilter(BaseFilterSet):
user = filters.UUIDFilter(field_name='user_id')
user_display = filters.CharFilter(field_name='user__name')
class Meta:
model = LoginACL
fields = (
'name', 'user', 'user_display', 'action'
)

View File

@ -0,0 +1,46 @@
# Generated by Django 3.2.17 on 2023-06-06 06:23
import uuid
import django.core.validators
from django.conf import settings
from django.db import migrations, models
import common.db.fields
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('acls', '0014_loginassetacl_rules'),
]
operations = [
migrations.CreateModel(
name='ConnectMethodACL',
fields=[
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('name', models.CharField(max_length=128, unique=True, verbose_name='Name')),
('priority', models.IntegerField(default=50, help_text='1-100, the lower the value will be match first',
validators=[django.core.validators.MinValueValidator(1),
django.core.validators.MaxValueValidator(100)],
verbose_name='Priority')),
('action', models.CharField(default='reject', max_length=64, verbose_name='Action')),
('is_active', models.BooleanField(default=True, verbose_name='Active')),
('users', common.db.fields.JSONManyToManyField(default=dict, to='users.User', verbose_name='Users')),
('connect_methods', models.JSONField(default=list, verbose_name='Connect methods')),
(
'reviewers',
models.ManyToManyField(blank=True, to=settings.AUTH_USER_MODEL, verbose_name='Reviewers')),
],
options={
'ordering': ('priority', 'date_updated', 'name'),
'abstract': False,
},
),
]

View File

@ -0,0 +1,37 @@
# Generated by Django 3.2.17 on 2023-06-06 10:57
from django.db import migrations, models
import common.db.fields
def migrate_users_login_acls(apps, schema_editor):
login_acl_model = apps.get_model('acls', 'LoginACL')
for login_acl in login_acl_model.objects.all():
login_acl.users = {
"type": "ids", "ids": [str(login_acl.user_id)]
}
login_acl.save()
class Migration(migrations.Migration):
dependencies = [
('acls', '0015_connectmethodacl'),
]
operations = [
migrations.AddField(
model_name='loginacl',
name='users',
field=common.db.fields.JSONManyToManyField(default=dict, to='users.User', verbose_name='Users'),
),
migrations.RemoveField(
model_name='loginacl',
name='user',
),
migrations.AlterField(
model_name='loginacl',
name='name',
field=models.CharField(max_length=128, unique=True, verbose_name='Name'),
),
]

View File

@ -1,3 +1,4 @@
from .command_acl import *
from .connect_method import *
from .login_acl import * from .login_acl import *
from .login_asset_acl import * from .login_asset_acl import *
from .command_acl import *

View File

@ -9,7 +9,7 @@ from common.utils.time_period import contains_time_period
from orgs.mixins.models import OrgModelMixin from orgs.mixins.models import OrgModelMixin
__all__ = [ __all__ = [
'BaseACL', 'UserAssetAccountBaseACL', 'BaseACL', 'UserBaseACL', 'UserAssetAccountBaseACL',
] ]
@ -34,7 +34,7 @@ class BaseACLQuerySet(models.QuerySet):
class BaseACL(JMSBaseModel): class BaseACL(JMSBaseModel):
name = models.CharField(max_length=128, verbose_name=_('Name')) name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
priority = models.IntegerField( priority = models.IntegerField(
default=50, verbose_name=_("Priority"), default=50, verbose_name=_("Priority"),
help_text=_("1-100, the lower the value will be match first"), help_text=_("1-100, the lower the value will be match first"),
@ -79,13 +79,27 @@ class BaseACL(JMSBaseModel):
return None return None
class UserAssetAccountBaseACL(BaseACL, OrgModelMixin): class UserBaseACL(BaseACL):
users = JSONManyToManyField('users.User', default=dict, verbose_name=_('Users')) users = JSONManyToManyField('users.User', default=dict, verbose_name=_('Users'))
class Meta:
abstract = True
@classmethod
def get_user_acls(cls, user):
queryset = cls.objects.all()
q = cls.users.get_filter_q(user)
queryset = queryset.filter(q)
return queryset.valid().distinct()
class UserAssetAccountBaseACL(UserBaseACL, OrgModelMixin):
name = models.CharField(max_length=128, verbose_name=_('Name'))
assets = JSONManyToManyField('assets.Asset', default=dict, verbose_name=_('Assets')) assets = JSONManyToManyField('assets.Asset', default=dict, verbose_name=_('Assets'))
accounts = models.JSONField(default=list, verbose_name=_("Accounts")) accounts = models.JSONField(default=list, verbose_name=_("Accounts"))
class Meta(BaseACL.Meta): class Meta(UserBaseACL.Meta):
unique_together = ('name', 'org_id') unique_together = [('name', 'org_id')]
abstract = True abstract = True
@classmethod @classmethod

View File

@ -0,0 +1,10 @@
from django.db import models
from django.utils.translation import gettext_lazy as _
from .base import UserBaseACL
__all__ = ['ConnectMethodACL']
class ConnectMethodACL(UserBaseACL):
connect_methods = models.JSONField(default=list, verbose_name=_('Connect methods'))

View File

@ -3,18 +3,14 @@ from django.utils.translation import ugettext_lazy as _
from common.utils import get_request_ip, get_ip_city from common.utils import get_request_ip, get_ip_city
from common.utils.timezone import local_now_display from common.utils.timezone import local_now_display
from .base import BaseACL from .base import UserBaseACL
class LoginACL(BaseACL): class LoginACL(UserBaseACL):
user = models.ForeignKey(
'users.User', on_delete=models.CASCADE,
related_name='login_acls', verbose_name=_('User')
)
# 规则, ip_group, time_period # 规则, ip_group, time_period
rules = models.JSONField(default=dict, verbose_name=_('Rule')) rules = models.JSONField(default=dict, verbose_name=_('Rule'))
class Meta(BaseACL.Meta): class Meta(UserBaseACL.Meta):
verbose_name = _('Login acl') verbose_name = _('Login acl')
abstract = False abstract = False
@ -28,10 +24,6 @@ class LoginACL(BaseACL):
def filter_acl(cls, user): def filter_acl(cls, user):
return user.login_acls.all().valid().distinct() return user.login_acls.all().valid().distinct()
@classmethod
def get_user_acls(cls, user):
return cls.filter_acl(user)
def create_confirm_ticket(self, request): def create_confirm_ticket(self, request):
from tickets import const from tickets import const
from tickets.models import ApplyLoginTicket from tickets.models import ApplyLoginTicket

View File

@ -1,4 +1,5 @@
from .command_acl import *
from .connect_method import *
from .login_acl import * from .login_acl import *
from .login_asset_acl import * from .login_asset_acl import *
from .login_asset_check import * from .login_asset_check import *
from .command_acl import *

View File

@ -1,11 +1,10 @@
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from acls.models.base import ActionChoices from acls.models.base import ActionChoices, BaseACL
from common.serializers.fields import JSONManyToManyField, LabeledChoiceField
from jumpserver.utils import has_valid_xpack_license from jumpserver.utils import has_valid_xpack_license
from common.serializers.fields import JSONManyToManyField, ObjectRelatedField, LabeledChoiceField
from orgs.models import Organization from orgs.models import Organization
from users.models import User
common_help_text = _( common_help_text = _(
"With * indicating a match all. " "With * indicating a match all. "
@ -71,25 +70,16 @@ class ActionAclSerializer(serializers.Serializer):
action._choices = choices action._choices = choices
class BaseUserAssetAccountACLSerializerMixin(ActionAclSerializer, serializers.Serializer): class BaserACLSerializer(ActionAclSerializer, serializers.Serializer):
users = JSONManyToManyField(label=_('User'))
assets = JSONManyToManyField(label=_('Asset'))
accounts = serializers.ListField(label=_('Account'))
reviewers = ObjectRelatedField(
queryset=User.objects, many=True, required=False, label=_('Reviewers')
)
reviewers_amount = serializers.IntegerField(
read_only=True, source="reviewers.count", label=_('Reviewers amount')
)
class Meta: class Meta:
model = BaseACL
fields_mini = ["id", "name"] fields_mini = ["id", "name"]
fields_small = fields_mini + [ fields_small = fields_mini + [
"users", "accounts", "assets", "is_active", "is_active", "priority", "action",
"date_created", "date_updated", "priority", "date_created", "date_updated",
"action", "comment", "created_by", "org_id", "comment", "created_by", "org_id",
] ]
fields_m2m = ["reviewers", "reviewers_amount"] fields_m2m = ["reviewers", ]
fields = fields_small + fields_m2m fields = fields_small + fields_m2m
extra_kwargs = { extra_kwargs = {
"priority": {"default": 50}, "priority": {"default": 50},
@ -115,3 +105,18 @@ class BaseUserAssetAccountACLSerializerMixin(ActionAclSerializer, serializers.Se
) )
raise serializers.ValidationError(error) raise serializers.ValidationError(error)
return valid_reviewers return valid_reviewers
class BaserUserACLSerializer(BaserACLSerializer):
users = JSONManyToManyField(label=_('User'))
class Meta(BaserACLSerializer.Meta):
fields = BaserACLSerializer.Meta.fields + ['users']
class BaseUserAssetAccountACLSerializer(BaserUserACLSerializer):
assets = JSONManyToManyField(label=_('Asset'))
accounts = serializers.ListField(label=_('Account'))
class Meta(BaserUserACLSerializer.Meta):
fields = BaserUserACLSerializer.Meta.fields + ['assets', 'accounts']

View File

@ -1,13 +1,13 @@
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from terminal.models import Session
from acls.models import CommandGroup, CommandFilterACL from acls.models import CommandGroup, CommandFilterACL
from common.utils import lazyproperty, get_object_or_none
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from orgs.utils import tmp_to_root_org from common.utils import lazyproperty, get_object_or_none
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaseUserAssetAccountACLSerializerMixin as BaseSerializer from orgs.utils import tmp_to_root_org
from terminal.models import Session
from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
__all__ = ["CommandFilterACLSerializer", "CommandGroupSerializer", "CommandReviewSerializer"] __all__ = ["CommandFilterACLSerializer", "CommandGroupSerializer", "CommandReviewSerializer"]

View File

@ -0,0 +1,23 @@
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
from ..models import ConnectMethodACL
__all__ = ["ConnectMethodACLSerializer"]
class ConnectMethodACLSerializer(BaseSerializer, BulkOrgResourceModelSerializer):
class Meta(BaseSerializer.Meta):
model = ConnectMethodACL
fields = [
i for i in BaseSerializer.Meta.fields + ['connect_methods']
if i not in ['assets', 'accounts']
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
field_action = self.fields.get('action')
if not field_action:
return
# 仅支持拒绝
for k in ['review', 'accept']:
field_action._choices.pop(k, None)

View File

@ -1,47 +1,22 @@
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from rest_framework import serializers
from common.serializers import BulkModelSerializer, MethodSerializer from common.serializers import MethodSerializer
from common.serializers.fields import ObjectRelatedField from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from users.models import User from .base import BaserUserACLSerializer
from .base import ActionAclSerializer
from .rules import RuleSerializer from .rules import RuleSerializer
from ..models import LoginACL from ..models import LoginACL
__all__ = [ __all__ = ["LoginACLSerializer"]
"LoginACLSerializer",
]
common_help_text = _( common_help_text = _("With * indicating a match all. ")
"With * indicating a match all. "
)
class LoginACLSerializer(ActionAclSerializer, BulkModelSerializer): class LoginACLSerializer(BaserUserACLSerializer, BulkOrgResourceModelSerializer):
user = ObjectRelatedField(queryset=User.objects, label=_("User"))
reviewers = ObjectRelatedField(
queryset=User.objects, label=_("Reviewers"), many=True, required=False
)
reviewers_amount = serializers.IntegerField(
read_only=True, source="reviewers.count", label=_("Reviewers amount")
)
rules = MethodSerializer(label=_('Rule')) rules = MethodSerializer(label=_('Rule'))
class Meta: class Meta(BaserUserACLSerializer.Meta):
model = LoginACL model = LoginACL
fields_mini = ["id", "name"] fields = BaserUserACLSerializer.Meta.fields + ['rules', ]
fields_small = fields_mini + [
"priority", "user", "rules", "action",
"is_active", "date_created", "date_updated",
"comment", "created_by",
]
fields_fk = ["user"]
fields_m2m = ["reviewers", "reviewers_amount"]
fields = fields_small + fields_fk + fields_m2m
extra_kwargs = {
"priority": {"default": 50},
"is_active": {"default": True},
}
def get_rules_serializer(self): def get_rules_serializer(self):
return RuleSerializer() return RuleSerializer()

View File

@ -2,7 +2,7 @@ from django.utils.translation import gettext_lazy as _
from common.serializers import MethodSerializer from common.serializers import MethodSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaseUserAssetAccountACLSerializerMixin as BaseSerializer from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
from .rules import RuleSerializer from .rules import RuleSerializer
from ..models import LoginAssetACL from ..models import LoginAssetACL

View File

@ -10,6 +10,7 @@ router.register(r'login-acls', api.LoginACLViewSet, 'login-acl')
router.register(r'login-asset-acls', api.LoginAssetACLViewSet, 'login-asset-acl') router.register(r'login-asset-acls', api.LoginAssetACLViewSet, 'login-asset-acl')
router.register(r'command-filter-acls', api.CommandFilterACLViewSet, 'command-filter-acl') router.register(r'command-filter-acls', api.CommandFilterACLViewSet, 'command-filter-acl')
router.register(r'command-groups', api.CommandGroupViewSet, 'command-group') router.register(r'command-groups', api.CommandGroupViewSet, 'command-group')
router.register(r'connect-method-acls', api.ConnectMethodACLViewSet, 'connect-method-acl')
urlpatterns = [ urlpatterns = [
path('login-asset/check/', api.LoginAssetCheckAPI.as_view(), name='login-asset-check'), path('login-asset/check/', api.LoginAssetCheckAPI.as_view(), name='login-asset-check'),

View File

@ -1,21 +1,21 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import re
from django.templatetags.static import static
from collections import OrderedDict
from itertools import chain
import logging
import datetime import datetime
import uuid
from functools import wraps
import time
import ipaddress import ipaddress
import psutil import logging
import platform
import os import os
import platform
import re
import socket import socket
import time
import uuid
from collections import OrderedDict
from functools import wraps
from itertools import chain
import psutil
from django.conf import settings from django.conf import settings
from django.templatetags.static import static
UUID_PATTERN = re.compile(r'\w{8}(-\w{4}){3}-\w{12}') UUID_PATTERN = re.compile(r'\w{8}(-\w{4}){3}-\w{12}')
ipip_db = None ipip_db = None
@ -76,6 +76,7 @@ def setattr_bulk(seq, key, value):
def set_attr(obj): def set_attr(obj):
setattr(obj, key, value) setattr(obj, key, value)
return obj return obj
return map(set_attr, seq) return map(set_attr, seq)
@ -97,12 +98,12 @@ def capacity_convert(size, expect='auto', rate=1000):
rate_mapping = ( rate_mapping = (
('K', rate), ('K', rate),
('KB', rate), ('KB', rate),
('M', rate**2), ('M', rate ** 2),
('MB', rate**2), ('MB', rate ** 2),
('G', rate**3), ('G', rate ** 3),
('GB', rate**3), ('GB', rate ** 3),
('T', rate**4), ('T', rate ** 4),
('TB', rate**4), ('TB', rate ** 4),
) )
rate_mapping = OrderedDict(rate_mapping) rate_mapping = OrderedDict(rate_mapping)
@ -117,7 +118,7 @@ def capacity_convert(size, expect='auto', rate=1000):
if expect == 'auto': if expect == 'auto':
for unit, rate_ in rate_mapping.items(): for unit, rate_ in rate_mapping.items():
if rate > std_size/rate_ >= 1 or unit == "T": if rate > std_size / rate_ >= 1 or unit == "T":
expect = unit expect = unit
break break
@ -195,6 +196,7 @@ def with_cache(func):
res = func(*args, **kwargs) res = func(*args, **kwargs)
cache[key] = res cache[key] = res
return res return res
return wrapper return wrapper
@ -216,6 +218,7 @@ def timeit(func):
msg = "End call {}, using: {:.1f}ms".format(name, using) msg = "End call {}, using: {:.1f}ms".format(name, using)
logger.debug(msg) logger.debug(msg)
return result return result
return wrapper return wrapper
@ -310,7 +313,7 @@ class Time:
def print(self): def print(self):
last, *timestamps = self._timestamps last, *timestamps = self._timestamps
for timestamp, msg in zip(timestamps, self._msgs): for timestamp, msg in zip(timestamps, self._msgs):
logger.debug(f'TIME_IT: {msg} {timestamp-last}') logger.debug(f'TIME_IT: {msg} {timestamp - last}')
last = timestamp last = timestamp
@ -367,7 +370,7 @@ def pretty_string(data, max_length=128, ellipsis_str='...'):
def group_by_count(it, count): def group_by_count(it, count):
return [it[i:i+count] for i in range(0, len(it), count)] return [it[i:i + count] for i in range(0, len(it), count)]
def test_ip_connectivity(host, port, timeout=0.5): def test_ip_connectivity(host, port, timeout=0.5):
@ -395,3 +398,17 @@ def static_or_direct(logo_path):
def make_dirs(name, mode=0o755, exist_ok=False): def make_dirs(name, mode=0o755, exist_ok=False):
""" 默认权限设置为 0o755 """ """ 默认权限设置为 0o755 """
return os.makedirs(name, mode=mode, exist_ok=exist_ok) return os.makedirs(name, mode=mode, exist_ok=exist_ok)
def distinct(seq, key=None):
if key is None:
# 如果未提供关键字参数,则默认使用元素本身作为比较键
key = lambda x: x
seen = set()
result = []
for item in seq:
k = key(item)
if k not in seen:
seen.add(k)
result.append(item)
return result

View File

@ -148,6 +148,8 @@ only_system_permissions = (
('orgs', 'organization', 'view', 'rootorg'), ('orgs', 'organization', 'view', 'rootorg'),
('terminal', 'applet', '*', '*'), ('terminal', 'applet', '*', '*'),
('terminal', 'applethost', '*', '*'), ('terminal', 'applethost', '*', '*'),
('acls', 'loginacl', '*', '*'),
('acls', 'connectmethodacl', '*', '*')
) )
only_org_permissions = ( only_org_permissions = (

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import itertools
from rest_framework import generics from rest_framework import generics
from rest_framework.views import Response from rest_framework.views import Response
from common.permissions import IsValidUser from common.permissions import IsValidUser
from common.utils import get_request_os from common.utils import get_request_os, is_true, distinct
from terminal import serializers from terminal import serializers
from terminal.connect_methods import ConnectMethodUtil from terminal.connect_methods import ConnectMethodUtil
@ -16,9 +17,29 @@ class ConnectMethodListApi(generics.ListAPIView):
serializer_class = serializers.ConnectMethodSerializer serializer_class = serializers.ConnectMethodSerializer
permission_classes = [IsValidUser] permission_classes = [IsValidUser]
def filter_user_connect_methods(self, d):
from acls.models import ConnectMethodACL
# 这里要根据用户来了,受 acl 影响
acls = ConnectMethodACL.get_user_acls(self.request.user)
disabled_connect_methods = acls.values_list('connect_methods', flat=True)
disabled_connect_methods = set(itertools.chain.from_iterable(disabled_connect_methods))
new_queryset = {}
for protocol, methods in d.items():
new_queryset[protocol] = [x for x in methods if x['value'] not in disabled_connect_methods]
return new_queryset
def get_queryset(self): def get_queryset(self):
os = get_request_os(self.request) os = self.request.query_params.get('os') or get_request_os(self.request)
return ConnectMethodUtil.get_filtered_protocols_connect_methods(os) queryset = ConnectMethodUtil.get_filtered_protocols_connect_methods(os)
flat = self.request.query_params.get('flat')
# 先这么处理, 这里不用过滤包含的事所有
if is_true(flat):
queryset = itertools.chain.from_iterable(queryset.values())
queryset = distinct(queryset, key=lambda x: x['value'])
else:
queryset = self.filter_queryset(queryset)
return queryset
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
queryset = self.get_queryset() queryset = self.get_queryset()

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import itertools
from collections import defaultdict from collections import defaultdict
from django.conf import settings from django.conf import settings
@ -51,11 +52,7 @@ class NativeClient(TextChoices):
xshell = 'xshell', 'Xshell' xshell = 'xshell', 'Xshell'
# Magnus # Magnus
mysql = 'db_client_mysql', _('DB Client') db_client = 'db_client', _('DB Client')
psql = 'db_client_psql', _('DB Client')
sqlplus = 'db_client_sqlplus', _('DB Client')
redis = 'db_client_redis', _('DB Client')
mongodb = 'db_client_mongodb', _('DB Client')
# Razor # Razor
mstsc = 'mstsc', 'Remote Desktop' mstsc = 'mstsc', 'Remote Desktop'
@ -70,12 +67,12 @@ class NativeClient(TextChoices):
'windows': [cls.putty], 'windows': [cls.putty],
}, },
Protocol.rdp: [cls.mstsc], Protocol.rdp: [cls.mstsc],
Protocol.mysql: [cls.mysql], Protocol.mysql: [cls.db_client],
Protocol.mariadb: [cls.mysql], Protocol.mariadb: [cls.db_client],
Protocol.oracle: [cls.sqlplus], Protocol.oracle: [cls.db_client],
Protocol.postgresql: [cls.psql], Protocol.postgresql: [cls.db_client],
Protocol.redis: [cls.redis], Protocol.redis: [cls.db_client],
Protocol.mongodb: [cls.mongodb], Protocol.mongodb: [cls.db_client],
} }
return clients return clients
@ -83,6 +80,9 @@ class NativeClient(TextChoices):
def get_target_protocol(cls, name, os): def get_target_protocol(cls, name, os):
for protocol, clients in cls.get_native_clients().items(): for protocol, clients in cls.get_native_clients().items():
if isinstance(clients, dict): if isinstance(clients, dict):
if os == 'all':
clients = list(itertools.chain(*clients.values()))
else:
clients = clients.get(os) or clients.get('default') clients = clients.get(os) or clients.get('default')
if name in clients: if name in clients:
return protocol return protocol
@ -99,6 +99,9 @@ class NativeClient(TextChoices):
for protocol, _clients in clients_map.items(): for protocol, _clients in clients_map.items():
if isinstance(_clients, dict): if isinstance(_clients, dict):
if os == 'all':
_clients = list(itertools.chain(*_clients.values()))
else:
_clients = _clients.get(os, _clients['default']) _clients = _clients.get(os, _clients['default'])
for client in _clients: for client in _clients:
if not settings.XPACK_ENABLED and client in cls.xpack_methods(): if not settings.XPACK_ENABLED and client in cls.xpack_methods():
@ -245,11 +248,10 @@ class ConnectMethodUtil:
if not getattr(settings, 'TERMINAL_KOKO_SSH_ENABLED'): if not getattr(settings, 'TERMINAL_KOKO_SSH_ENABLED'):
protocol = Protocol.ssh protocol = Protocol.ssh
methods[protocol] = [m for m in methods[protocol] if m['type'] != 'native'] methods[protocol] = [m for m in methods[protocol] if m['type'] != 'native']
return methods return methods
@classmethod @classmethod
def get_protocols_connect_methods(cls, os): def get_protocols_connect_methods(cls, os='windows'):
if cls._all_methods.get('os'): if cls._all_methods.get('os'):
return cls._all_methods['os'] return cls._all_methods['os']
@ -264,7 +266,7 @@ class ConnectMethodUtil:
for protocol in support: for protocol in support:
# Web 方式 # Web 方式
methods[protocol.value].extend([ methods[str(protocol)].extend([
{ {
'component': component.value, 'component': component.value,
'type': 'web', 'type': 'web',
@ -286,7 +288,7 @@ class ConnectMethodUtil:
if component == TerminalType.koko and protocol.value != Protocol.ssh: if component == TerminalType.koko and protocol.value != Protocol.ssh:
# koko 仅支持 ssh 的 native 方式,其他数据库的 native 方式不提供 # koko 仅支持 ssh 的 native 方式,其他数据库的 native 方式不提供
continue continue
methods[protocol.value].extend([ methods[str(protocol)].extend([
{ {
'component': component.value, 'component': component.value,
'type': 'native', 'type': 'native',