From 57e12256e7a1a931cf16bcd5f1d9501c6899eca5 Mon Sep 17 00:00:00 2001 From: "Jiangjie.Bai" Date: Thu, 22 Sep 2022 15:52:47 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BF=AE=E6=94=B9=20Endpoint=20?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=20Manugs=20DB=20listen=20port=20=E7=9A=84?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/applications/const.py | 6 --- apps/applications/models/application.py | 10 ----- .../attrs/application_type/oracle.py | 6 --- apps/applications/utils/db_port_mapper.py | 3 ++ apps/authentication/api/connection_token.py | 5 ++- apps/terminal/api/endpoint.py | 14 +++---- apps/terminal/models/endpoint.py | 42 ++++++++++--------- apps/terminal/serializers/endpoint.py | 29 +++++++------ 8 files changed, 48 insertions(+), 67 deletions(-) diff --git a/apps/applications/const.py b/apps/applications/const.py index 4e0d2fe50..313477c25 100644 --- a/apps/applications/const.py +++ b/apps/applications/const.py @@ -83,9 +83,3 @@ class AppType(models.TextChoices): if AppCategory.is_xpack(category): return True return tp in ['oracle', 'postgresql', 'sqlserver'] - - -class OracleVersion(models.TextChoices): - version_11g = '11g', '11g' - version_12c = '12c', '12c' - version_other = 'other', _('Other') diff --git a/apps/applications/models/application.py b/apps/applications/models/application.py index 91507c2cb..3661188fe 100644 --- a/apps/applications/models/application.py +++ b/apps/applications/models/application.py @@ -10,7 +10,6 @@ from common.mixins import CommonModelMixin from common.tree import TreeNode from common.utils import is_uuid from assets.models import Asset, SystemUser -from ..const import OracleVersion from .. import const @@ -304,15 +303,6 @@ class Application(CommonModelMixin, OrgModelMixin, ApplicationTreeNodeMixin): target_ip = self.attrs.get('host') return target_ip - def get_target_protocol_for_oracle(self): - """ Oracle 类型需要单独处理,因为要携带版本号 """ - if not self.is_type(self.APP_TYPE.oracle): - return - version = self.attrs.get('version', OracleVersion.version_12c) - if version == OracleVersion.version_other: - return - return 'oracle_%s' % version - class ApplicationUser(SystemUser): class Meta: diff --git a/apps/applications/serializers/attrs/application_type/oracle.py b/apps/applications/serializers/attrs/application_type/oracle.py index fdc8016d2..c87c4904d 100644 --- a/apps/applications/serializers/attrs/application_type/oracle.py +++ b/apps/applications/serializers/attrs/application_type/oracle.py @@ -2,15 +2,9 @@ from rest_framework import serializers from django.utils.translation import ugettext_lazy as _ from ..application_category import DBSerializer -from applications.const import OracleVersion __all__ = ['OracleSerializer'] class OracleSerializer(DBSerializer): - version = serializers.ChoiceField( - choices=OracleVersion.choices, default=OracleVersion.version_12c, - allow_null=True, label=_('Version'), - help_text=_('Magnus currently supports only 11g and 12c connections') - ) port = serializers.IntegerField(default=1521, label=_('Port'), allow_null=True) diff --git a/apps/applications/utils/db_port_mapper.py b/apps/applications/utils/db_port_mapper.py index 0b6bdd93a..19fc7ce12 100644 --- a/apps/applications/utils/db_port_mapper.py +++ b/apps/applications/utils/db_port_mapper.py @@ -49,6 +49,9 @@ class DBPortManager(object): for port, db_id in mapper.items(): if db_id == str(db.id): return port + logger.warning( + 'Not matched db port, db_id: {}, mapper length: {}'.format(db.id, len(mapper)) + ) def get_db_by_port(self, port): mapper = self.get_mapper() diff --git a/apps/authentication/api/connection_token.py b/apps/authentication/api/connection_token.py index 2b4ce7e2b..e52f76577 100644 --- a/apps/authentication/api/connection_token.py +++ b/apps/authentication/api/connection_token.py @@ -62,12 +62,15 @@ class ConnectionTokenMixin: def get_smart_endpoint(self, protocol, asset=None, application=None): if asset: + target_instance = asset target_ip = asset.get_target_ip() elif application: + target_instance = application target_ip = application.get_target_ip() else: + target_instance = None target_ip = '' - endpoint = EndpointRule.match_endpoint(target_ip, protocol, self.request) + endpoint = EndpointRule.match_endpoint(target_instance, target_ip, protocol, self.request) return endpoint @staticmethod diff --git a/apps/terminal/api/endpoint.py b/apps/terminal/api/endpoint.py index ca745d412..37de98576 100644 --- a/apps/terminal/api/endpoint.py +++ b/apps/terminal/api/endpoint.py @@ -47,11 +47,12 @@ class SmartEndpointViewMixin: return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol) def match_endpoint_by_target_ip(self): - # 用来方便测试 - target_ip = self.request.GET.get('target_ip', '') + target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数,用来方便测试 if not target_ip and callable(getattr(self.target_instance, 'get_target_ip', None)): target_ip = self.target_instance.get_target_ip() - endpoint = EndpointRule.match_endpoint(target_ip, self.target_protocol, self.request) + endpoint = EndpointRule.match_endpoint( + self.target_instance, target_ip, self.target_protocol, self.request + ) return endpoint def get_target_instance(self): @@ -83,12 +84,7 @@ class SmartEndpointViewMixin: return instance def get_target_protocol(self): - protocol = None - if isinstance(self.target_instance, Application) and self.target_instance.is_type(Application.APP_TYPE.oracle): - protocol = self.target_instance.get_target_protocol_for_oracle() - if not protocol: - protocol = self.request.GET.get('protocol') - return protocol + return self.request.GET.get('protocol') class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet): diff --git a/apps/terminal/models/endpoint.py b/apps/terminal/models/endpoint.py index 305443821..98cc6a328 100644 --- a/apps/terminal/models/endpoint.py +++ b/apps/terminal/models/endpoint.py @@ -1,27 +1,23 @@ from django.db import models from django.utils.translation import ugettext_lazy as _ from django.core.validators import MinValueValidator, MaxValueValidator +from applications.models import Application +from applications.utils import db_port_manager from common.db.models import JMSModel from common.db.fields import PortField from common.utils.ip import contains_ip +from common.exceptions import JMSException class Endpoint(JMSModel): name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True) host = models.CharField(max_length=256, blank=True, verbose_name=_('Host')) - # disabled value=0 + # value=0 表示 disabled https_port = PortField(default=443, verbose_name=_('HTTPS Port')) http_port = PortField(default=80, verbose_name=_('HTTP Port')) ssh_port = PortField(default=2222, verbose_name=_('SSH Port')) rdp_port = PortField(default=3389, verbose_name=_('RDP Port')) - # Todo: Delete - mysql_port = PortField(default=33060, verbose_name=_('MySQL Port')) - mariadb_port = PortField(default=33061, verbose_name=_('MariaDB Port')) - postgresql_port = PortField(default=54320, verbose_name=_('PostgreSQL Port')) - redis_port = PortField(default=63790, verbose_name=_('Redis Port')) - oracle_11g_port = PortField(default=15211, verbose_name=_('Oracle 11g Port')) - oracle_12c_port = PortField(default=15212, verbose_name=_('Oracle 12c Port')) comment = models.TextField(default='', blank=True, verbose_name=_('Comment')) default_id = '00000000-0000-0000-0000-000000000001' @@ -33,12 +29,18 @@ class Endpoint(JMSModel): def __str__(self): return self.name - def get_port(self, protocol): - return getattr(self, f'{protocol}_port', 0) - - def get_oracle_port(self, version): - protocol = f'oracle_{version}' - return self.get_port(protocol) + def get_port(self, target_instance, protocol): + if protocol in ['https', 'http', 'ssh', 'rdp']: + port = getattr(self, f'{protocol}_port', 0) + elif isinstance(target_instance, Application) and target_instance.category_db: + port = db_port_manager.get_port_by_db(target_instance) + if port is None: + error = 'No application port is matched, application id: {}' \ + ''.format(target_instance.id) + raise JMSException(error) + else: + port = 0 + return port def is_default(self): return str(self.id) == self.default_id @@ -48,10 +50,10 @@ class Endpoint(JMSModel): return return super().delete(using, keep_parents) - def is_valid_for(self, protocol): + def is_valid_for(self, target_instance, protocol): if self.is_default(): return True - if self.host and self.get_port(protocol) != 0: + if self.host and self.get_port(target_instance, protocol) != 0: return True return False @@ -105,19 +107,19 @@ class EndpointRule(JMSModel): return f'{self.name}({self.priority})' @classmethod - def match(cls, target_ip, protocol): + def match(cls, target_instance, target_ip, protocol): for endpoint_rule in cls.objects.all().prefetch_related('endpoint'): if not contains_ip(target_ip, endpoint_rule.ip_group): continue if not endpoint_rule.endpoint: continue - if not endpoint_rule.endpoint.is_valid_for(protocol): + if not endpoint_rule.endpoint.is_valid_for(target_instance, protocol): continue return endpoint_rule @classmethod - def match_endpoint(cls, target_ip, protocol, request=None): - endpoint_rule = cls.match(target_ip, protocol) + def match_endpoint(cls, target_instance, target_ip, protocol, request=None): + endpoint_rule = cls.match(target_instance, target_ip, protocol) if endpoint_rule: endpoint = endpoint_rule.endpoint else: diff --git a/apps/terminal/serializers/endpoint.py b/apps/terminal/serializers/endpoint.py index 3d8e858ac..d32d21a13 100644 --- a/apps/terminal/serializers/endpoint.py +++ b/apps/terminal/serializers/endpoint.py @@ -2,25 +2,23 @@ from rest_framework import serializers from django.utils.translation import ugettext_lazy as _ from common.drf.serializers import BulkModelSerializer from acls.serializers.rules import ip_group_child_validator, ip_group_help_text +from django.conf import settings from ..models import Endpoint, EndpointRule __all__ = ['EndpointSerializer', 'EndpointRuleSerializer'] class EndpointSerializer(BulkModelSerializer): - # 解决 luna 处理繁琐的问题,oracle_port 返回匹配到的端口 - oracle_port = serializers.SerializerMethodField(label=_('Oracle port')) + # 解决 luna 处理繁琐的问题, 返回 magnus 监听的当前 db 的 port + magnus_listen_db_port = serializers.SerializerMethodField(label=_('Magnus listen db port')) + magnus_listen_port_range = serializers.SerializerMethodField(label=_('Magnus Listen port range')) class Meta: model = Endpoint fields_mini = ['id', 'name'] fields_small = [ 'host', - 'https_port', 'http_port', 'ssh_port', - 'rdp_port', 'mysql_port', 'mariadb_port', - 'postgresql_port', 'redis_port', - 'oracle_11g_port', 'oracle_12c_port', - 'oracle_port', + 'https_port', 'http_port', 'ssh_port', 'rdp_port', ] fields = fields_mini + fields_small + [ 'comment', 'date_created', 'date_updated', 'created_by' @@ -30,19 +28,20 @@ class EndpointSerializer(BulkModelSerializer): 'http_port': {'default': 80}, 'ssh_port': {'default': 2222}, 'rdp_port': {'default': 3389}, - 'mysql_port': {'default': 33060}, - 'mariadb_port': {'default': 33061}, - 'postgresql_port': {'default': 54320}, - 'redis_port': {'default': 63790}, - 'oracle_11g_port': {'default': 15211}, - 'oracle_12c_port': {'default': 15212}, } - def get_oracle_port(self, obj: Endpoint): + def get_magnus_listen_db_port(self, obj: Endpoint): view = self.context.get('view') if not view or view.action not in ['smart']: return 0 - return obj.get_port(view.target_protocol) + return obj.get_port(view.target_instance, view.target_protocol) + + @staticmethod + def get_magnus_listen_port_range(obj: Endpoint): + port_start = settings.MAGNUS_DB_PORTS_START + port_limit = settings.MAGNUS_DB_PORTS_LIMIT_COUNT + port_end = port_start + port_limit + 1 + return f'{port_start} - {port_end}' class EndpointRuleSerializer(BulkModelSerializer):