feat: 修改 Endpoint 获取 Manugs DB listen port 的逻辑

pull/8892/head
Jiangjie.Bai 2022-09-22 15:52:47 +08:00
parent b8ec60dea1
commit 57e12256e7
8 changed files with 48 additions and 67 deletions

View File

@ -83,9 +83,3 @@ class AppType(models.TextChoices):
if AppCategory.is_xpack(category):
return True
return tp in ['oracle', 'postgresql', 'sqlserver']
class OracleVersion(models.TextChoices):
version_11g = '11g', '11g'
version_12c = '12c', '12c'
version_other = 'other', _('Other')

View File

@ -10,7 +10,6 @@ from common.mixins import CommonModelMixin
from common.tree import TreeNode
from common.utils import is_uuid
from assets.models import Asset, SystemUser
from ..const import OracleVersion
from .. import const
@ -304,15 +303,6 @@ class Application(CommonModelMixin, OrgModelMixin, ApplicationTreeNodeMixin):
target_ip = self.attrs.get('host')
return target_ip
def get_target_protocol_for_oracle(self):
""" Oracle 类型需要单独处理,因为要携带版本号 """
if not self.is_type(self.APP_TYPE.oracle):
return
version = self.attrs.get('version', OracleVersion.version_12c)
if version == OracleVersion.version_other:
return
return 'oracle_%s' % version
class ApplicationUser(SystemUser):
class Meta:

View File

@ -2,15 +2,9 @@ from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from ..application_category import DBSerializer
from applications.const import OracleVersion
__all__ = ['OracleSerializer']
class OracleSerializer(DBSerializer):
version = serializers.ChoiceField(
choices=OracleVersion.choices, default=OracleVersion.version_12c,
allow_null=True, label=_('Version'),
help_text=_('Magnus currently supports only 11g and 12c connections')
)
port = serializers.IntegerField(default=1521, label=_('Port'), allow_null=True)

View File

@ -49,6 +49,9 @@ class DBPortManager(object):
for port, db_id in mapper.items():
if db_id == str(db.id):
return port
logger.warning(
'Not matched db port, db_id: {}, mapper length: {}'.format(db.id, len(mapper))
)
def get_db_by_port(self, port):
mapper = self.get_mapper()

View File

@ -62,12 +62,15 @@ class ConnectionTokenMixin:
def get_smart_endpoint(self, protocol, asset=None, application=None):
if asset:
target_instance = asset
target_ip = asset.get_target_ip()
elif application:
target_instance = application
target_ip = application.get_target_ip()
else:
target_instance = None
target_ip = ''
endpoint = EndpointRule.match_endpoint(target_ip, protocol, self.request)
endpoint = EndpointRule.match_endpoint(target_instance, target_ip, protocol, self.request)
return endpoint
@staticmethod

View File

@ -47,11 +47,12 @@ class SmartEndpointViewMixin:
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
def match_endpoint_by_target_ip(self):
# 用来方便测试
target_ip = self.request.GET.get('target_ip', '')
target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数用来方便测试
if not target_ip and callable(getattr(self.target_instance, 'get_target_ip', None)):
target_ip = self.target_instance.get_target_ip()
endpoint = EndpointRule.match_endpoint(target_ip, self.target_protocol, self.request)
endpoint = EndpointRule.match_endpoint(
self.target_instance, target_ip, self.target_protocol, self.request
)
return endpoint
def get_target_instance(self):
@ -83,12 +84,7 @@ class SmartEndpointViewMixin:
return instance
def get_target_protocol(self):
protocol = None
if isinstance(self.target_instance, Application) and self.target_instance.is_type(Application.APP_TYPE.oracle):
protocol = self.target_instance.get_target_protocol_for_oracle()
if not protocol:
protocol = self.request.GET.get('protocol')
return protocol
return self.request.GET.get('protocol')
class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet):

View File

@ -1,27 +1,23 @@
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from applications.models import Application
from applications.utils import db_port_manager
from common.db.models import JMSModel
from common.db.fields import PortField
from common.utils.ip import contains_ip
from common.exceptions import JMSException
class Endpoint(JMSModel):
name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
host = models.CharField(max_length=256, blank=True, verbose_name=_('Host'))
# disabled value=0
# value=0 表示 disabled
https_port = PortField(default=443, verbose_name=_('HTTPS Port'))
http_port = PortField(default=80, verbose_name=_('HTTP Port'))
ssh_port = PortField(default=2222, verbose_name=_('SSH Port'))
rdp_port = PortField(default=3389, verbose_name=_('RDP Port'))
# Todo: Delete
mysql_port = PortField(default=33060, verbose_name=_('MySQL Port'))
mariadb_port = PortField(default=33061, verbose_name=_('MariaDB Port'))
postgresql_port = PortField(default=54320, verbose_name=_('PostgreSQL Port'))
redis_port = PortField(default=63790, verbose_name=_('Redis Port'))
oracle_11g_port = PortField(default=15211, verbose_name=_('Oracle 11g Port'))
oracle_12c_port = PortField(default=15212, verbose_name=_('Oracle 12c Port'))
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
default_id = '00000000-0000-0000-0000-000000000001'
@ -33,12 +29,18 @@ class Endpoint(JMSModel):
def __str__(self):
return self.name
def get_port(self, protocol):
return getattr(self, f'{protocol}_port', 0)
def get_oracle_port(self, version):
protocol = f'oracle_{version}'
return self.get_port(protocol)
def get_port(self, target_instance, protocol):
if protocol in ['https', 'http', 'ssh', 'rdp']:
port = getattr(self, f'{protocol}_port', 0)
elif isinstance(target_instance, Application) and target_instance.category_db:
port = db_port_manager.get_port_by_db(target_instance)
if port is None:
error = 'No application port is matched, application id: {}' \
''.format(target_instance.id)
raise JMSException(error)
else:
port = 0
return port
def is_default(self):
return str(self.id) == self.default_id
@ -48,10 +50,10 @@ class Endpoint(JMSModel):
return
return super().delete(using, keep_parents)
def is_valid_for(self, protocol):
def is_valid_for(self, target_instance, protocol):
if self.is_default():
return True
if self.host and self.get_port(protocol) != 0:
if self.host and self.get_port(target_instance, protocol) != 0:
return True
return False
@ -105,19 +107,19 @@ class EndpointRule(JMSModel):
return f'{self.name}({self.priority})'
@classmethod
def match(cls, target_ip, protocol):
def match(cls, target_instance, target_ip, protocol):
for endpoint_rule in cls.objects.all().prefetch_related('endpoint'):
if not contains_ip(target_ip, endpoint_rule.ip_group):
continue
if not endpoint_rule.endpoint:
continue
if not endpoint_rule.endpoint.is_valid_for(protocol):
if not endpoint_rule.endpoint.is_valid_for(target_instance, protocol):
continue
return endpoint_rule
@classmethod
def match_endpoint(cls, target_ip, protocol, request=None):
endpoint_rule = cls.match(target_ip, protocol)
def match_endpoint(cls, target_instance, target_ip, protocol, request=None):
endpoint_rule = cls.match(target_instance, target_ip, protocol)
if endpoint_rule:
endpoint = endpoint_rule.endpoint
else:

View File

@ -2,25 +2,23 @@ from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from common.drf.serializers import BulkModelSerializer
from acls.serializers.rules import ip_group_child_validator, ip_group_help_text
from django.conf import settings
from ..models import Endpoint, EndpointRule
__all__ = ['EndpointSerializer', 'EndpointRuleSerializer']
class EndpointSerializer(BulkModelSerializer):
# 解决 luna 处理繁琐的问题oracle_port 返回匹配到的端口
oracle_port = serializers.SerializerMethodField(label=_('Oracle port'))
# 解决 luna 处理繁琐的问题, 返回 magnus 监听的当前 db 的 port
magnus_listen_db_port = serializers.SerializerMethodField(label=_('Magnus listen db port'))
magnus_listen_port_range = serializers.SerializerMethodField(label=_('Magnus Listen port range'))
class Meta:
model = Endpoint
fields_mini = ['id', 'name']
fields_small = [
'host',
'https_port', 'http_port', 'ssh_port',
'rdp_port', 'mysql_port', 'mariadb_port',
'postgresql_port', 'redis_port',
'oracle_11g_port', 'oracle_12c_port',
'oracle_port',
'https_port', 'http_port', 'ssh_port', 'rdp_port',
]
fields = fields_mini + fields_small + [
'comment', 'date_created', 'date_updated', 'created_by'
@ -30,19 +28,20 @@ class EndpointSerializer(BulkModelSerializer):
'http_port': {'default': 80},
'ssh_port': {'default': 2222},
'rdp_port': {'default': 3389},
'mysql_port': {'default': 33060},
'mariadb_port': {'default': 33061},
'postgresql_port': {'default': 54320},
'redis_port': {'default': 63790},
'oracle_11g_port': {'default': 15211},
'oracle_12c_port': {'default': 15212},
}
def get_oracle_port(self, obj: Endpoint):
def get_magnus_listen_db_port(self, obj: Endpoint):
view = self.context.get('view')
if not view or view.action not in ['smart']:
return 0
return obj.get_port(view.target_protocol)
return obj.get_port(view.target_instance, view.target_protocol)
@staticmethod
def get_magnus_listen_port_range(obj: Endpoint):
port_start = settings.MAGNUS_DB_PORTS_START
port_limit = settings.MAGNUS_DB_PORTS_LIMIT_COUNT
port_end = port_start + port_limit + 1
return f'{port_start} - {port_end}'
class EndpointRuleSerializer(BulkModelSerializer):