From dbee3ed30d96d3012727421eeb1ad4c4d136bc72 Mon Sep 17 00:00:00 2001 From: ibuler Date: Wed, 7 Dec 2022 15:09:01 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20connect=20token=20=E6=B7=BB=E5=8A=A0=20?= =?UTF-8?q?Rdp=20options?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../signal_handlers/node_assets_mapping.py | 6 +- apps/authentication/api/connection_token.py | 72 +++-- .../authentication/models/connection_token.py | 78 +++++- .../serializers/connect_token_secret.py | 54 ++-- .../serializers/connection_token.py | 1 - apps/jumpserver/settings/base.py | 11 +- apps/notifications/signal_handlers.py | 27 +- apps/orgs/signal_handlers/common.py | 23 +- apps/settings/signal_handlers.py | 6 +- apps/terminal/api/component/__init__.py | 7 +- .../terminal/api/component/connect_methods.py | 25 ++ apps/terminal/api/component/terminal.py | 17 +- apps/terminal/connect_methods.py | 255 ++++++++++++++++++ apps/terminal/const.py | 229 ---------------- apps/terminal/models/applet/applet.py | 42 ++- apps/terminal/signal_handlers.py | 30 ++- 16 files changed, 535 insertions(+), 348 deletions(-) create mode 100644 apps/terminal/api/component/connect_methods.py create mode 100644 apps/terminal/connect_methods.py diff --git a/apps/assets/signal_handlers/node_assets_mapping.py b/apps/assets/signal_handlers/node_assets_mapping.py index b242f3be8..27640fc76 100644 --- a/apps/assets/signal_handlers/node_assets_mapping.py +++ b/apps/assets/signal_handlers/node_assets_mapping.py @@ -20,13 +20,9 @@ logger = get_logger(__file__) # ------------------------------------ -def get_node_assets_mapping_for_memory_pub_sub(): - return RedisPubSub('fm.node_all_asset_ids_memory_mapping') - - class NodeAssetsMappingForMemoryPubSub(LazyObject): def _setup(self): - self._wrapped = get_node_assets_mapping_for_memory_pub_sub() + self._wrapped = RedisPubSub('fm.node_all_asset_ids_memory_mapping') node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub() diff --git a/apps/authentication/api/connection_token.py b/apps/authentication/api/connection_token.py index 97ef3b5ff..60528066e 100644 --- a/apps/authentication/api/connection_token.py +++ b/apps/authentication/api/connection_token.py @@ -19,12 +19,12 @@ from common.utils import random_string from common.utils.django import get_request_os from orgs.mixins.api import RootOrgViewMixin from perms.models import ActionChoices -from terminal.const import NativeClient, TerminalType +from terminal.connect_methods import NativeClient, ConnectMethodUtil from terminal.models import EndpointRule, Applet from ..models import ConnectionToken from ..serializers import ( ConnectionTokenSerializer, ConnectionTokenSecretSerializer, - SuperConnectionTokenSerializer, + SuperConnectionTokenSerializer, ConnectTokenAppletOptionSerializer ) __all__ = ['ConnectionTokenViewSet', 'SuperConnectionTokenViewSet'] @@ -115,7 +115,8 @@ class RDPFileClientProtocolURLMixin: rdp_options['audiomode:i'] = self.parse_env_bool('JUMPSERVER_DISABLE_AUDIO', 'false', '2', '0') # 设置远程应用 - self.set_applet_info(token, rdp_options) + remote_app_options = token.get_remote_app_option() + rdp_options.update(remote_app_options) # 文件名 name = token.asset.name @@ -145,7 +146,7 @@ class RDPFileClientProtocolURLMixin: _os = get_request_os(self.request) connect_method_name = token.connect_method - connect_method_dict = TerminalType.get_connect_method( + connect_method_dict = ConnectMethodUtil.get_connect_method( token.connect_method, token.protocol, _os ) if connect_method_dict is None: @@ -227,38 +228,16 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView search_fields = filterset_fields serializer_classes = { 'default': ConnectionTokenSerializer, - 'get_secret_detail': ConnectionTokenSecretSerializer, } rbac_perms = { 'list': 'authentication.view_connectiontoken', 'retrieve': 'authentication.view_connectiontoken', 'create': 'authentication.add_connectiontoken', 'expire': 'authentication.add_connectiontoken', - 'get_secret_detail': 'authentication.view_connectiontokensecret', 'get_rdp_file': 'authentication.add_connectiontoken', 'get_client_protocol_url': 'authentication.add_connectiontoken', } - @action(methods=['POST'], detail=False, url_path='secret') - def get_secret_detail(self, request, *args, **kwargs): - """ 非常重要的 api, 在逻辑层再判断一下 rbac 权限, 双重保险 """ - rbac_perm = 'authentication.view_connectiontokensecret' - if not request.user.has_perm(rbac_perm): - raise PermissionDenied('Not allow to view secret') - - token_id = request.data.get('id') or '' - token = get_object_or_404(ConnectionToken, pk=token_id) - if token.is_expired: - raise ValidationError({'id': 'Token is expired'}) - - token.is_valid() - serializer = self.get_serializer(instance=token) - expire_now = request.data.get('expire_now', True) - if expire_now: - token.expire() - - return Response(serializer.data, status=status.HTTP_200_OK) - def get_queryset(self): queryset = ConnectionToken.objects \ .filter(user=self.request.user) \ @@ -305,10 +284,14 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView class SuperConnectionTokenViewSet(ConnectionTokenViewSet): serializer_classes = { 'default': SuperConnectionTokenSerializer, + 'get_secret_detail': ConnectionTokenSecretSerializer, } rbac_perms = { 'create': 'authentication.add_superconnectiontoken', - 'renewal': 'authentication.add_superconnectiontoken' + 'renewal': 'authentication.add_superconnectiontoken', + 'get_secret_detail': 'authentication.view_connectiontokensecret', + 'get_applet_info': 'authentication.view_superconnectiontoken', + 'release_applet_account': 'authentication.view_superconnectiontoken', } def get_queryset(self): @@ -332,3 +315,38 @@ class SuperConnectionTokenViewSet(ConnectionTokenViewSet): 'msg': f'Token is renewed, date expired: {date_expired}' } return Response(data=data, status=status.HTTP_200_OK) + + @action(methods=['POST'], detail=False, url_path='secret') + def get_secret_detail(self, request, *args, **kwargs): + """ 非常重要的 api, 在逻辑层再判断一下 rbac 权限, 双重保险 """ + rbac_perm = 'authentication.view_connectiontokensecret' + if not request.user.has_perm(rbac_perm): + raise PermissionDenied('Not allow to view secret') + + token_id = request.data.get('id') or '' + token = get_object_or_404(ConnectionToken, pk=token_id) + if token.is_expired: + raise ValidationError({'id': 'Token is expired'}) + + token.is_valid() + serializer = self.get_serializer(instance=token) + expire_now = request.data.get('expire_now', True) + if expire_now: + token.expire() + return Response(serializer.data, status=status.HTTP_200_OK) + + @action(methods=['POST'], detail=False, url_path='applet-option') + def get_applet_info(self, *args, **kwargs): + token_id = self.request.data.get('id') + token = get_object_or_404(ConnectionToken, pk=token_id) + if token.is_expired: + return Response({'error': 'Token expired'}, status=status.HTTP_400_BAD_REQUEST) + data = token.get_applet_option() + serializer = ConnectTokenAppletOptionSerializer(data) + return Response(serializer.data) + + @action(methods=['DELETE', 'POST'], detail=False, url_path='applet-account/release') + def release_applet_account(self, *args, **kwargs): + account_id = self.request.data.get('id') + msg = ConnectionToken.release_applet_account(account_id) + return Response({'msg': msg}) diff --git a/apps/authentication/models/connection_token.py b/apps/authentication/models/connection_token.py index 4a23d18b8..421ec0969 100644 --- a/apps/authentication/models/connection_token.py +++ b/apps/authentication/models/connection_token.py @@ -1,6 +1,9 @@ +import base64 +import json from datetime import timedelta from django.conf import settings +from django.core.cache import cache from django.db import models from django.utils import timezone from django.utils.translation import ugettext_lazy as _ @@ -9,9 +12,10 @@ from rest_framework.exceptions import PermissionDenied from assets.const import Protocol from common.db.fields import EncryptCharField from common.db.models import JMSBaseModel -from common.utils import lazyproperty, pretty_string +from common.utils import lazyproperty, pretty_string, bulk_get from common.utils.timezone import as_current_tz from orgs.mixins.models import OrgModelMixin +from terminal.models import Applet def date_expired_default(): @@ -101,6 +105,9 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel): error = _('No account') raise PermissionDenied(error) + if timezone.now() - self.date_created < timedelta(seconds=60): + return True, None + if not self.permed_account or not self.permed_account.actions: msg = 'user `{}` not has asset `{}` permission for login `{}`'.format( self.user, self.asset, self.account @@ -115,6 +122,75 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel): def platform(self): return self.asset.platform + @lazyproperty + def connect_method_object(self): + from common.utils import get_request_os + from jumpserver.utils import get_current_request + from terminal.connect_methods import ConnectMethodUtil + + request = get_current_request() + os = get_request_os(request) if request else 'windows' + method = ConnectMethodUtil.get_connect_method( + self.connect_method, protocol=self.protocol, os=os + ) + return method + + def get_remote_app_option(self): + cmdline = { + 'app_name': self.connect_method, + 'user_id': str(self.user.id), + 'asset_id': str(self.asset.id), + 'token_id': str(self.id) + } + cmdline_b64 = base64.b64encode(json.dumps(cmdline).encode()).decode() + app = '||tinker' + options = { + 'remoteapplicationmode:i': '1', + 'remoteapplicationprogram:s': app, + 'remoteapplicationname:s': app, + 'alternate shell:s': app, + 'remoteapplicationcmdline:s': cmdline_b64, + } + return options + + def get_applet_option(self): + method = self.connect_method_object + if not method or method.get('type') != 'applet' or method.get('disabled', False): + return None + + applet = Applet.objects.filter(name=method.get('value')).first() + if not applet: + return None + + host_account = applet.select_host_account() + if not host_account: + return None + + host, account, lock_key, ttl = bulk_get(host_account, ('host', 'account', 'lock_key', 'ttl')) + gateway = host.gateway.select_gateway() if host.domain else None + + data = { + 'id': account.id, + 'applet': applet, + 'host': host, + 'gateway': gateway, + 'account': account, + 'remote_app_option': self.get_remote_app_option() + } + token_account_relate_key = f'token_account_relate_{account.id}' + cache.set(token_account_relate_key, lock_key, ttl) + return data + + @staticmethod + def release_applet_account(account_id): + token_account_relate_key = f'token_account_relate_{account_id}' + lock_key = cache.get(token_account_relate_key) + if lock_key: + cache.delete(lock_key) + cache.delete(token_account_relate_key) + return 'released' + return 'not found or expired' + @lazyproperty def account_object(self): from assets.models import Account diff --git a/apps/authentication/serializers/connect_token_secret.py b/apps/authentication/serializers/connect_token_secret.py index f24f3e9c6..e4fd20be0 100644 --- a/apps/authentication/serializers/connect_token_secret.py +++ b/apps/authentication/serializers/connect_token_secret.py @@ -1,19 +1,17 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers -from common.drf.fields import ObjectRelatedField from acls.models import CommandGroup, CommandFilterACL from assets.models import Asset, Account, Platform, Gateway, Domain from assets.serializers import PlatformSerializer, AssetProtocolsSerializer -from users.models import User -from perms.serializers.permission import ActionChoicesField +from common.drf.fields import ObjectRelatedField from orgs.mixins.serializers import OrgResourceModelSerializerMixin - +from perms.serializers.permission import ActionChoicesField +from users.models import User from ..models import ConnectionToken - __all__ = [ - 'ConnectionTokenSecretSerializer', + 'ConnectionTokenSecretSerializer', 'ConnectTokenAppletOptionSerializer' ] @@ -96,6 +94,24 @@ class _ConnectionTokenPlatformSerializer(PlatformSerializer): return names +class _ConnectionTokenConnectMethodSerializer(serializers.Serializer): + name = serializers.CharField(label=_('Name')) + protocol = serializers.CharField(label=_('Protocol')) + os = serializers.CharField(label=_('OS')) + is_builtin = serializers.BooleanField(label=_('Is builtin')) + is_active = serializers.BooleanField(label=_('Is active')) + platform = _ConnectionTokenPlatformSerializer(label=_('Platform')) + action = ActionChoicesField(label=_('Action')) + options = serializers.JSONField(label=_('Options')) + + +class _ConnectTokenConnectMethodSerializer(serializers.Serializer): + label = serializers.CharField(label=_('Label')) + value = serializers.CharField(label=_('Value')) + type = serializers.CharField(label=_('Type')) + component = serializers.CharField(label=_('Component')) + + class ConnectionTokenSecretSerializer(OrgResourceModelSerializerMixin): user = _ConnectionTokenUserSerializer(read_only=True) asset = _ConnectionTokenAssetSerializer(read_only=True) @@ -104,30 +120,28 @@ class ConnectionTokenSecretSerializer(OrgResourceModelSerializerMixin): platform = _ConnectionTokenPlatformSerializer(read_only=True) domain = ObjectRelatedField(queryset=Domain.objects, required=False, label=_('Domain')) command_filter_acls = _ConnectionTokenCommandFilterACLSerializer(read_only=True, many=True) + expire_now = serializers.BooleanField(label=_('Expired now'), write_only=True, default=True) + connect_method = _ConnectTokenConnectMethodSerializer(read_only=True, source='connect_method_object') actions = ActionChoicesField() expire_at = serializers.IntegerField() - expire_now = serializers.BooleanField(label=_('Expired now'), write_only=True, default=True) - connect_method = serializers.SerializerMethodField(label=_('Connect method')) class Meta: model = ConnectionToken fields = [ 'id', 'value', 'user', 'asset', 'account', 'platform', 'command_filter_acls', 'protocol', - 'domain', 'gateway', 'actions', 'expire_at', 'expire_now', - 'connect_method' + 'domain', 'gateway', 'actions', 'expire_at', + 'expire_now', 'connect_method', ] extra_kwargs = { 'value': {'read_only': True}, } - def get_connect_method(self, obj): - from terminal.const import TerminalType - from common.utils import get_request_os - request = self.context.get('request') - if request: - os = get_request_os(request) - else: - os = 'windows' - method = TerminalType.get_connect_method(obj.connect_method, protocol=obj.protocol, os=os) - return method + +class ConnectTokenAppletOptionSerializer(serializers.Serializer): + id = serializers.CharField(label=_('ID')) + applet = ObjectRelatedField(read_only=True) + host = _ConnectionTokenAssetSerializer(read_only=True) + account = _ConnectionTokenAccountSerializer(read_only=True) + gateway = _ConnectionTokenGatewaySerializer(read_only=True) + remote_app_option = serializers.JSONField(read_only=True) diff --git a/apps/authentication/serializers/connection_token.py b/apps/authentication/serializers/connection_token.py index 2b5b156e8..e45037853 100644 --- a/apps/authentication/serializers/connection_token.py +++ b/apps/authentication/serializers/connection_token.py @@ -2,7 +2,6 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers from orgs.mixins.serializers import OrgResourceModelSerializerMixin - from ..models import ConnectionToken __all__ = [ diff --git a/apps/jumpserver/settings/base.py b/apps/jumpserver/settings/base.py index aaedd2ddc..3b19bb16e 100644 --- a/apps/jumpserver/settings/base.py +++ b/apps/jumpserver/settings/base.py @@ -3,7 +3,6 @@ import platform from redis.sentinel import SentinelManagedSSLConnection - if platform.system() == 'Darwin' and platform.machine() == 'arm64': import pymysql @@ -308,17 +307,22 @@ else: REDIS_SENTINEL_SOCKET_TIMEOUT = None # Cache config + REDIS_OPTIONS = { "REDIS_CLIENT_KWARGS": { "health_check_interval": 30 }, "CONNECTION_POOL_KWARGS": { + 'max_connections': 100, + } +} +if REDIS_USE_SSL: + REDIS_OPTIONS['CONNECTION_POOL_KWARGS'].update({ 'ssl_cert_reqs': REDIS_SSL_REQUIRED, "ssl_keyfile": REDIS_SSL_KEY, "ssl_certfile": REDIS_SSL_CERT, "ssl_ca_certs": REDIS_SSL_CA - } if REDIS_USE_SSL else {} -} + }) if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS: REDIS_LOCATION_NO_DB = "%(protocol)s://%(service_name)s/{}" % { @@ -348,7 +352,6 @@ else: 'host': CONFIG.REDIS_HOST, 'port': CONFIG.REDIS_PORT, } - REDIS_CACHE_DEFAULT = { 'BACKEND': 'redis_lock.django_cache.RedisCache', 'LOCATION': REDIS_LOCATION_NO_DB.format(CONFIG.REDIS_DB_CACHE), diff --git a/apps/notifications/signal_handlers.py b/apps/notifications/signal_handlers.py index aaf3480bf..c0b1f1c1c 100644 --- a/apps/notifications/signal_handlers.py +++ b/apps/notifications/signal_handlers.py @@ -1,32 +1,26 @@ -import json -from importlib import import_module import inspect +from importlib import import_module -from django.utils.functional import LazyObject -from django.db.models.signals import post_save -from django.db.models.signals import post_migrate -from django.dispatch import receiver from django.apps import AppConfig +from django.db.models.signals import post_migrate +from django.db.models.signals import post_save +from django.dispatch import receiver +from django.utils.functional import LazyObject +from common.decorator import on_transaction_commit +from common.utils import get_logger +from common.utils.connection import RedisPubSub from notifications.backends import BACKEND from users.models import User -from common.utils.connection import RedisPubSub -from common.utils import get_logger -from common.decorator import on_transaction_commit from .models import SiteMessage, SystemMsgSubscription, UserMsgSubscription from .notifications import SystemMessage - logger = get_logger(__name__) -def new_site_msg_pub_sub(): - return RedisPubSub('notifications.SiteMessageCome') - - class NewSiteMsgSubPub(LazyObject): def _setup(self): - self._wrapped = new_site_msg_pub_sub() + self._wrapped = RedisPubSub('notifications.SiteMessageCome') new_site_msg_chan = NewSiteMsgSubPub() @@ -78,7 +72,8 @@ def create_system_messages(app_config: AppConfig, **kwargs): sub, created = SystemMsgSubscription.objects.get_or_create(message_type=message_type) if created: obj.post_insert_to_db(sub) - logger.info(f'Create SystemMsgSubscription: package={app_config.module.__package__} type={message_type}') + logger.info( + f'Create SystemMsgSubscription: package={app_config.module.__package__} type={message_type}') except ModuleNotFoundError: pass diff --git a/apps/orgs/signal_handlers/common.py b/apps/orgs/signal_handlers/common.py index b32cba2cd..4935eeec9 100644 --- a/apps/orgs/signal_handlers/common.py +++ b/apps/orgs/signal_handlers/common.py @@ -3,35 +3,30 @@ from collections import defaultdict from functools import partial -import django.db.utils -from django.dispatch import receiver from django.conf import settings -from django.db.utils import ProgrammingError, OperationalError -from django.utils.functional import LazyObject from django.db.models.signals import post_save, pre_delete, m2m_changed +from django.db.utils import ProgrammingError, OperationalError +from django.dispatch import receiver +from django.utils.functional import LazyObject -from orgs.utils import tmp_to_org, set_to_default_org -from orgs.models import Organization -from orgs.hands import set_current_org, Node, get_current_org -from perms.models import AssetPermission -from users.models import UserGroup, User from common.const.signals import PRE_REMOVE, POST_REMOVE from common.decorator import on_transaction_commit from common.signals import django_ready from common.utils import get_logger from common.utils.connection import RedisPubSub +from orgs.hands import set_current_org, Node, get_current_org +from orgs.models import Organization +from orgs.utils import tmp_to_org, set_to_default_org +from perms.models import AssetPermission +from users.models import UserGroup, User from users.signals import post_user_leave_org logger = get_logger(__file__) -def get_orgs_mapping_for_memory_pub_sub(): - return RedisPubSub('fm.orgs_mapping') - - class OrgsMappingForMemoryPubSub(LazyObject): def _setup(self): - self._wrapped = get_orgs_mapping_for_memory_pub_sub() + self._wrapped = RedisPubSub('fm.orgs_mapping') orgs_mapping_for_memory_pub_sub = OrgsMappingForMemoryPubSub() diff --git a/apps/settings/signal_handlers.py b/apps/settings/signal_handlers.py index c963488f1..b04ae4f5e 100644 --- a/apps/settings/signal_handlers.py +++ b/apps/settings/signal_handlers.py @@ -18,13 +18,9 @@ from .models import Setting logger = get_logger(__file__) -def get_settings_pub_sub(): - return RedisPubSub('settings') - - class SettingSubPub(LazyObject): def _setup(self): - self._wrapped = get_settings_pub_sub() + self._wrapped = RedisPubSub('settings') setting_pub_sub = SettingSubPub() diff --git a/apps/terminal/api/component/__init__.py b/apps/terminal/api/component/__init__.py index afefe0c18..56432ca42 100644 --- a/apps/terminal/api/component/__init__.py +++ b/apps/terminal/api/component/__init__.py @@ -1,4 +1,5 @@ -from .terminal import * -from .storage import * -from .status import * +from .connect_methods import * from .endpoint import * +from .status import * +from .storage import * +from .terminal import * diff --git a/apps/terminal/api/component/connect_methods.py b/apps/terminal/api/component/connect_methods.py new file mode 100644 index 000000000..a284159d3 --- /dev/null +++ b/apps/terminal/api/component/connect_methods.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# + +from rest_framework import generics +from rest_framework.views import Response + +from common.permissions import IsValidUser +from common.utils import get_request_os +from terminal import serializers +from terminal.connect_methods import ConnectMethodUtil + +__all__ = ['ConnectMethodListApi'] + + +class ConnectMethodListApi(generics.ListAPIView): + serializer_class = serializers.ConnectMethodSerializer + permission_classes = [IsValidUser] + + def get_queryset(self): + os = get_request_os(self.request) + return ConnectMethodUtil.get_protocols_connect_methods(os) + + def list(self, request, *args, **kwargs): + queryset = self.get_queryset() + return Response(queryset) diff --git a/apps/terminal/api/component/terminal.py b/apps/terminal/api/component/terminal.py index df14296f5..d32adf02b 100644 --- a/apps/terminal/api/component/terminal.py +++ b/apps/terminal/api/component/terminal.py @@ -10,16 +10,13 @@ from rest_framework.views import APIView, Response from common.drf.api import JMSBulkModelViewSet from common.exceptions import JMSException -from common.permissions import IsValidUser from common.permissions import WithBootstrapToken -from common.utils import get_request_os from terminal import serializers -from terminal.const import TerminalType from terminal.models import Terminal __all__ = [ 'TerminalViewSet', 'TerminalConfig', - 'TerminalRegistrationApi', 'ConnectMethodListApi' + 'TerminalRegistrationApi', ] logger = logging.getLogger(__file__) @@ -72,15 +69,3 @@ class TerminalRegistrationApi(generics.CreateAPIView): return Response(data=data, status=status.HTTP_400_BAD_REQUEST) return super().create(request, *args, **kwargs) - -class ConnectMethodListApi(generics.ListAPIView): - serializer_class = serializers.ConnectMethodSerializer - permission_classes = [IsValidUser] - - def get_queryset(self): - os = get_request_os(self.request) - return TerminalType.get_protocols_connect_methods(os) - - def list(self, request, *args, **kwargs): - queryset = self.get_queryset() - return Response(queryset) diff --git a/apps/terminal/connect_methods.py b/apps/terminal/connect_methods.py new file mode 100644 index 000000000..ef0b2e943 --- /dev/null +++ b/apps/terminal/connect_methods.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +# +from collections import defaultdict + +from django.db.models import TextChoices +from django.utils.translation import ugettext_lazy as _ + +from assets.const import Protocol +from .const import TerminalType + + +class WebMethod(TextChoices): + web_gui = 'web_gui', 'Web GUI' + web_cli = 'web_cli', 'Web CLI' + web_sftp = 'web_sftp', 'Web SFTP' + + @classmethod + def get_methods(cls): + return { + Protocol.ssh: [cls.web_cli, cls.web_sftp], + Protocol.telnet: [cls.web_cli], + Protocol.rdp: [cls.web_gui], + Protocol.vnc: [cls.web_gui], + + Protocol.mysql: [cls.web_cli, cls.web_gui], + Protocol.mariadb: [cls.web_cli, cls.web_gui], + Protocol.oracle: [cls.web_cli, cls.web_gui], + Protocol.postgresql: [cls.web_cli, cls.web_gui], + Protocol.sqlserver: [cls.web_cli, cls.web_gui], + Protocol.redis: [cls.web_cli], + Protocol.mongodb: [cls.web_cli], + + Protocol.k8s: [cls.web_gui], + Protocol.http: [] + } + + +class NativeClient(TextChoices): + # Koko + ssh = 'ssh', 'SSH' + putty = 'putty', 'PuTTY' + xshell = 'xshell', 'Xshell' + + # Magnus + mysql = 'db_client_mysql', _('DB Client') + psql = 'db_client_psql', _('DB Client') + sqlplus = 'db_client_sqlplus', _('DB Client') + redis = 'db_client_redis', _('DB Client') + mongodb = 'db_client_mongodb', _('DB Client') + + # Razor + mstsc = 'mstsc', 'Remote Desktop' + + @classmethod + def get_native_clients(cls): + # native client 关注的是 endpoint 的 protocol, + # 比如 telnet mysql, koko 都支持,到那时暴露的是 ssh 协议 + clients = { + Protocol.ssh: { + 'default': [cls.ssh], + 'windows': [cls.putty], + }, + Protocol.rdp: [cls.mstsc], + Protocol.mysql: [cls.mysql], + Protocol.oracle: [cls.sqlplus], + Protocol.postgresql: [cls.psql], + Protocol.redis: [cls.redis], + Protocol.mongodb: [cls.mongodb], + } + return clients + + @classmethod + def get_target_protocol(cls, name, os): + for protocol, clients in cls.get_native_clients().items(): + if isinstance(clients, dict): + clients = clients.get(os) or clients.get('default') + if name in clients: + return protocol + return None + + @classmethod + def get_methods(cls, os='windows'): + clients_map = cls.get_native_clients() + methods = defaultdict(list) + + for protocol, _clients in clients_map.items(): + if isinstance(_clients, dict): + _clients = _clients.get(os, _clients['default']) + for client in _clients: + methods[protocol].append({ + 'value': client.value, + 'label': client.label, + 'type': 'native', + }) + return methods + + @classmethod + def get_launch_command(cls, name, token, endpoint, os='windows'): + username = f'JMS-{token.id}' + commands = { + cls.ssh: f'ssh {username}@{endpoint.host} -p {endpoint.ssh_port}', + cls.putty: f'putty.exe -ssh {username}@{endpoint.host} -P {endpoint.ssh_port}', + cls.xshell: f'xshell.exe -url ssh://{username}:{token.value}@{endpoint.host}:{endpoint.ssh_port}', + # cls.mysql: 'mysql -h {hostname} -P {port} -u {username} -p', + # cls.psql: { + # 'default': 'psql -h {hostname} -p {port} -U {username} -W', + # 'windows': 'psql /h {hostname} /p {port} /U {username} -W', + # }, + # cls.sqlplus: 'sqlplus {username}/{password}@{hostname}:{port}', + # cls.redis: 'redis-cli -h {hostname} -p {port} -a {password}', + } + command = commands.get(name) + if isinstance(command, dict): + command = command.get(os, command.get('default')) + return command + + +class AppletMethod: + @classmethod + def get_methods(cls): + from .models import Applet, AppletHost + + methods = defaultdict(list) + has_applet_hosts = AppletHost.objects.all().exists() + + applets = Applet.objects.filter(is_active=True) + for applet in applets: + for protocol in applet.protocols: + methods[protocol].append({ + 'value': applet.name, + 'label': applet.display_name, + 'type': 'applet', + 'icon': applet.icon, + 'disabled': not applet.is_active or not has_applet_hosts, + }) + return methods + + +class ConnectMethodUtil: + _all_methods = None + + @classmethod + def protocols(cls): + protocols = { + TerminalType.koko: { + 'web_methods': [WebMethod.web_cli, WebMethod.web_sftp], + 'listen': [Protocol.ssh, Protocol.http], + 'support': [ + Protocol.ssh, Protocol.telnet, + Protocol.mysql, Protocol.postgresql, + Protocol.oracle, Protocol.sqlserver, + Protocol.mariadb, Protocol.redis, + Protocol.mongodb, Protocol.k8s, + ], + 'match': 'm2m' + }, + TerminalType.omnidb: { + 'web_methods': [WebMethod.web_gui], + 'listen': [Protocol.http], + 'support': [ + Protocol.mysql, Protocol.postgresql, Protocol.oracle, + Protocol.sqlserver, Protocol.mariadb + ], + 'match': 'm2m' + }, + TerminalType.lion: { + 'web_methods': [WebMethod.web_gui], + 'listen': [Protocol.http], + 'support': [Protocol.rdp, Protocol.vnc], + 'match': 'm2m' + }, + TerminalType.magnus: { + 'listen': [], + 'support': [ + Protocol.mysql, Protocol.postgresql, + Protocol.oracle, Protocol.mariadb + ], + 'match': 'map' + }, + TerminalType.razor: { + 'listen': [Protocol.rdp], + 'support': [Protocol.rdp], + 'match': 'map' + }, + } + return protocols + + @classmethod + def get_connect_method(cls, name, protocol, os='linux'): + methods = cls.get_protocols_connect_methods(os) + protocol_methods = methods.get(protocol, []) + for method in protocol_methods: + if method['value'] == name: + return method + return None + + @classmethod + def refresh_methods(cls): + cls._all_methods = None + + @classmethod + def get_protocols_connect_methods(cls, os): + if cls._all_methods is not None: + return cls._all_methods + + methods = defaultdict(list) + web_methods = WebMethod.get_methods() + native_methods = NativeClient.get_methods(os) + applet_methods = AppletMethod.get_methods() + + for component, component_protocol in cls.protocols().items(): + support = component_protocol['support'] + + for protocol in support: + # Web 方式 + protocol_web_methods = set(web_methods.get(protocol, [])) \ + & set(component_protocol.get('web_methods', [])) + methods[protocol.value].extend([ + { + 'component': component.value, + 'type': 'web', + 'endpoint_protocol': 'http', + 'value': method.value, + 'label': method.label, + } + for method in protocol_web_methods + ]) + + # 客户端方式 + if component_protocol['match'] == 'map': + listen = [protocol] + else: + listen = component_protocol['listen'] + + for listen_protocol in listen: + # Native method + methods[protocol.value].extend([ + { + 'component': component.value, + 'type': 'native', + 'endpoint_protocol': listen_protocol, + **method + } + for method in native_methods[listen_protocol] + ]) + + # 远程应用方式,这个只有 tinker 提供 + for protocol, applet_methods in applet_methods.items(): + for method in applet_methods: + method['listen'] = 'rdp' + method['component'] = TerminalType.tinker.value + methods[protocol].extend(applet_methods) + + cls._all_methods = methods + return methods diff --git a/apps/terminal/const.py b/apps/terminal/const.py index c5ceb3c94..9fcd4cb72 100644 --- a/apps/terminal/const.py +++ b/apps/terminal/const.py @@ -1,12 +1,9 @@ # -*- coding: utf-8 -*- # -from collections import defaultdict from django.db.models import TextChoices from django.utils.translation import ugettext_lazy as _ -from assets.const import Protocol - # Replay & Command Storage Choices # -------------------------------- @@ -44,128 +41,6 @@ class ComponentLoad(TextChoices): return set(dict(cls.choices).keys()) -class WebMethod(TextChoices): - web_gui = 'web_gui', 'Web GUI' - web_cli = 'web_cli', 'Web CLI' - web_sftp = 'web_sftp', 'Web SFTP' - - @classmethod - def get_methods(cls): - return { - Protocol.ssh: [cls.web_cli, cls.web_sftp], - Protocol.telnet: [cls.web_cli], - Protocol.rdp: [cls.web_gui], - Protocol.vnc: [cls.web_gui], - - Protocol.mysql: [cls.web_cli, cls.web_gui], - Protocol.mariadb: [cls.web_cli, cls.web_gui], - Protocol.oracle: [cls.web_cli, cls.web_gui], - Protocol.postgresql: [cls.web_cli, cls.web_gui], - Protocol.sqlserver: [cls.web_cli, cls.web_gui], - Protocol.redis: [cls.web_cli], - Protocol.mongodb: [cls.web_cli], - - Protocol.k8s: [cls.web_gui], - Protocol.http: [] - } - - -class NativeClient(TextChoices): - # Koko - ssh = 'ssh', 'SSH' - putty = 'putty', 'PuTTY' - xshell = 'xshell', 'Xshell' - - # Magnus - mysql = 'db_client_mysql', _('DB Client') - psql = 'db_client_psql', _('DB Client') - sqlplus = 'db_client_sqlplus', _('DB Client') - redis = 'db_client_redis', _('DB Client') - mongodb = 'db_client_mongodb', _('DB Client') - - # Razor - mstsc = 'mstsc', 'Remote Desktop' - - @classmethod - def get_native_clients(cls): - # native client 关注的是 endpoint 的 protocol, - # 比如 telnet mysql, koko 都支持,到那时暴露的是 ssh 协议 - clients = { - Protocol.ssh: { - 'default': [cls.ssh], - 'windows': [cls.putty], - }, - Protocol.rdp: [cls.mstsc], - Protocol.mysql: [cls.mysql], - Protocol.oracle: [cls.sqlplus], - Protocol.postgresql: [cls.psql], - Protocol.redis: [cls.redis], - Protocol.mongodb: [cls.mongodb], - } - return clients - - @classmethod - def get_target_protocol(cls, name, os): - for protocol, clients in cls.get_native_clients().items(): - if isinstance(clients, dict): - clients = clients.get(os) or clients.get('default') - if name in clients: - return protocol - return None - - @classmethod - def get_methods(cls, os='windows'): - clients_map = cls.get_native_clients() - methods = defaultdict(list) - - for protocol, _clients in clients_map.items(): - if isinstance(_clients, dict): - _clients = _clients.get(os, _clients['default']) - for client in _clients: - methods[protocol].append({ - 'value': client.value, - 'label': client.label, - 'type': 'native', - }) - return methods - - @classmethod - def get_launch_command(cls, name, token, endpoint, os='windows'): - username = f'JMS-{token.id}' - commands = { - cls.ssh: f'ssh {username}@{endpoint.host} -p {endpoint.ssh_port}', - cls.putty: f'putty.exe -ssh {username}@{endpoint.host} -P {endpoint.ssh_port}', - cls.xshell: f'xshell.exe -url ssh://{username}:{token.value}@{endpoint.host}:{endpoint.ssh_port}', - # cls.mysql: 'mysql -h {hostname} -P {port} -u {username} -p', - # cls.psql: { - # 'default': 'psql -h {hostname} -p {port} -U {username} -W', - # 'windows': 'psql /h {hostname} /p {port} /U {username} -W', - # }, - # cls.sqlplus: 'sqlplus {username}/{password}@{hostname}:{port}', - # cls.redis: 'redis-cli -h {hostname} -p {port} -a {password}', - } - command = commands.get(name) - if isinstance(command, dict): - command = command.get(os, command.get('default')) - return command - - -class AppletMethod: - @classmethod - def get_methods(cls): - from .models import Applet - applets = Applet.objects.all() - methods = defaultdict(list) - for applet in applets: - for protocol in applet.protocols: - methods[protocol].append({ - 'value': applet.name, - 'label': applet.display_name, - 'icon': applet.icon, - }) - return methods - - class TerminalType(TextChoices): koko = 'koko', 'KoKo' guacamole = 'guacamole', 'Guacamole' @@ -181,107 +56,3 @@ class TerminalType(TextChoices): @classmethod def types(cls): return set(dict(cls.choices).keys()) - - @classmethod - def protocols(cls): - protocols = { - cls.koko: { - 'web_methods': [WebMethod.web_cli, WebMethod.web_sftp], - 'listen': [Protocol.ssh, Protocol.http], - 'support': [ - Protocol.ssh, Protocol.telnet, - Protocol.mysql, Protocol.postgresql, - Protocol.oracle, Protocol.sqlserver, - Protocol.mariadb, Protocol.redis, - Protocol.mongodb, Protocol.k8s, - ], - 'match': 'm2m' - }, - cls.omnidb: { - 'web_methods': [WebMethod.web_gui], - 'listen': [Protocol.http], - 'support': [ - Protocol.mysql, Protocol.postgresql, Protocol.oracle, - Protocol.sqlserver, Protocol.mariadb - ], - 'match': 'm2m' - }, - cls.lion: { - 'web_methods': [WebMethod.web_gui], - 'listen': [Protocol.http], - 'support': [Protocol.rdp, Protocol.vnc], - 'match': 'm2m' - }, - cls.magnus: { - 'listen': [], - 'support': [ - Protocol.mysql, Protocol.postgresql, - Protocol.oracle, Protocol.mariadb - ], - 'match': 'map' - }, - cls.razor: { - 'listen': [Protocol.rdp], - 'support': [Protocol.rdp], - 'match': 'map' - }, - } - return protocols - - @classmethod - def get_connect_method(cls, name, protocol, os='linux'): - methods = cls.get_protocols_connect_methods(os) - protocol_methods = methods.get(protocol, []) - for method in protocol_methods: - if method['value'] == name: - return method - return None - - @classmethod - def get_protocols_connect_methods(cls, os): - methods = defaultdict(list) - web_methods = WebMethod.get_methods() - native_methods = NativeClient.get_methods(os) - applet_methods = AppletMethod.get_methods() - - for component, component_protocol in cls.protocols().items(): - support = component_protocol['support'] - - for protocol in support: - if component_protocol['match'] == 'map': - listen = [protocol] - else: - listen = component_protocol['listen'] - - for listen_protocol in listen: - # Native method - methods[protocol.value].extend([ - { - 'component': component.value, - 'type': 'native', - 'endpoint_protocol': listen_protocol, - **method - } - for method in native_methods[listen_protocol] - ]) - - protocol_web_methods = set(web_methods.get(protocol, [])) \ - & set(component_protocol.get('web_methods', [])) - methods[protocol.value].extend([ - { - 'component': component.value, - 'type': 'web', - 'endpoint_protocol': 'http', - 'value': method.value, - 'label': method.label, - } - for method in protocol_web_methods - ]) - - for protocol, applet_methods in applet_methods.items(): - for method in applet_methods: - method['type'] = 'applet' - method['listen'] = 'rdp' - method['component'] = cls.tinker.value - methods[protocol].extend(applet_methods) - return methods diff --git a/apps/terminal/models/applet/applet.py b/apps/terminal/models/applet/applet.py index cf854b036..bfe0e5e67 100644 --- a/apps/terminal/models/applet/applet.py +++ b/apps/terminal/models/applet/applet.py @@ -1,14 +1,15 @@ -import yaml import os.path +import random +import yaml from django.conf import settings +from django.core.cache import cache from django.core.files.storage import default_storage from django.db import models from django.utils.translation import gettext_lazy as _ from common.db.models import JMSBaseModel - __all__ = ['Applet', 'AppletPublication'] @@ -53,10 +54,43 @@ class Applet(JMSBaseModel): return None return os.path.join(settings.MEDIA_URL, 'applets', self.name, 'icon.png') + def select_host_account(self): + hosts = list(self.hosts.all()) + if not hosts: + return None + + host = random.choice(hosts) + using_keys = cache.keys('host_accounts_{}_*'.format(host.id)) or [] + accounts_used = cache.get_many(using_keys) + accounts = host.accounts.all().exclude(username__in=accounts_used) + + if not accounts: + accounts = host.accounts.all() + if not accounts: + return None + + account = random.choice(accounts) + ttl = 60 * 60 * 24 + lock_key = 'applet_host_accounts_{}_{}'.format(host.id, account.username) + cache.set(lock_key, account.username, ttl) + return { + 'host': host, + 'account': account, + 'lock_key': lock_key, + 'ttl': ttl + } + + @staticmethod + def release_host_and_account(host_id, username): + key = 'applet_host_accounts_{}_{}'.format(host_id, username) + cache.delete(key) + class AppletPublication(JMSBaseModel): - applet = models.ForeignKey('Applet', on_delete=models.PROTECT, related_name='publications', verbose_name=_('Applet')) - host = models.ForeignKey('AppletHost', on_delete=models.PROTECT, related_name='publications', verbose_name=_('Host')) + applet = models.ForeignKey('Applet', on_delete=models.PROTECT, related_name='publications', + verbose_name=_('Applet')) + host = models.ForeignKey('AppletHost', on_delete=models.PROTECT, related_name='publications', + verbose_name=_('Host')) status = models.CharField(max_length=16, default='ready', verbose_name=_('Status')) comment = models.TextField(default='', blank=True, verbose_name=_('Comment')) diff --git a/apps/terminal/signal_handlers.py b/apps/terminal/signal_handlers.py index 7868159a8..d98e21eee 100644 --- a/apps/terminal/signal_handlers.py +++ b/apps/terminal/signal_handlers.py @@ -4,17 +4,18 @@ from django.db.models.signals import post_save, post_delete from django.db.utils import ProgrammingError from django.dispatch import receiver +from django.utils.functional import LazyObject +from assets.models import Asset from common.signals import django_ready from common.utils import get_logger +from common.utils.connection import RedisPubSub from orgs.utils import tmp_to_builtin_org -from assets.models import Asset -from .utils import db_port_manager, DBPortManager from .models import Applet, AppletHost +from .utils import db_port_manager, DBPortManager db_port_manager: DBPortManager - logger = get_logger(__file__) @@ -27,6 +28,8 @@ def on_applet_host_create(sender, instance, created=False, **kwargs): with tmp_to_builtin_org(system=1): instance.generate_accounts() + applet_host_change_pub_sub.publish(True) + @receiver(post_save, sender=Applet) def on_applet_create(sender, instance, created=False, **kwargs): @@ -35,6 +38,8 @@ def on_applet_create(sender, instance, created=False, **kwargs): hosts = AppletHost.objects.all() instance.hosts.set(hosts) + applet_host_change_pub_sub.publish(True) + @receiver(django_ready) def init_db_port_mapper(sender, **kwargs): @@ -59,3 +64,22 @@ def on_db_app_delete(sender, instance, **kwargs): if not instance.category != 'database': return db_port_manager.pop(instance) + + +class AppletHostPubSub(LazyObject): + def _setup(self): + self._wrapped = RedisPubSub('fm.applet_host_change') + + +@receiver(django_ready) +def subscribe_applet_host_change(sender, **kwargs): + logger.debug("Start subscribe for expire node assets id mapping from memory") + + def on_change(message): + from terminal.connect_methods import ConnectMethodUtil + ConnectMethodUtil.refresh_methods() + + applet_host_change_pub_sub.subscribe(on_change) + + +applet_host_change_pub_sub = AppletHostPubSub()