mirror of https://github.com/jumpserver/jumpserver
feat: 修改 Endpoint 获取 Manugs DB listen port 的逻辑
parent
b8ec60dea1
commit
57e12256e7
|
@ -83,9 +83,3 @@ class AppType(models.TextChoices):
|
|||
if AppCategory.is_xpack(category):
|
||||
return True
|
||||
return tp in ['oracle', 'postgresql', 'sqlserver']
|
||||
|
||||
|
||||
class OracleVersion(models.TextChoices):
|
||||
version_11g = '11g', '11g'
|
||||
version_12c = '12c', '12c'
|
||||
version_other = 'other', _('Other')
|
||||
|
|
|
@ -10,7 +10,6 @@ from common.mixins import CommonModelMixin
|
|||
from common.tree import TreeNode
|
||||
from common.utils import is_uuid
|
||||
from assets.models import Asset, SystemUser
|
||||
from ..const import OracleVersion
|
||||
|
||||
from .. import const
|
||||
|
||||
|
@ -304,15 +303,6 @@ class Application(CommonModelMixin, OrgModelMixin, ApplicationTreeNodeMixin):
|
|||
target_ip = self.attrs.get('host')
|
||||
return target_ip
|
||||
|
||||
def get_target_protocol_for_oracle(self):
|
||||
""" Oracle 类型需要单独处理,因为要携带版本号 """
|
||||
if not self.is_type(self.APP_TYPE.oracle):
|
||||
return
|
||||
version = self.attrs.get('version', OracleVersion.version_12c)
|
||||
if version == OracleVersion.version_other:
|
||||
return
|
||||
return 'oracle_%s' % version
|
||||
|
||||
|
||||
class ApplicationUser(SystemUser):
|
||||
class Meta:
|
||||
|
|
|
@ -2,15 +2,9 @@ from rest_framework import serializers
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from ..application_category import DBSerializer
|
||||
from applications.const import OracleVersion
|
||||
|
||||
__all__ = ['OracleSerializer']
|
||||
|
||||
|
||||
class OracleSerializer(DBSerializer):
|
||||
version = serializers.ChoiceField(
|
||||
choices=OracleVersion.choices, default=OracleVersion.version_12c,
|
||||
allow_null=True, label=_('Version'),
|
||||
help_text=_('Magnus currently supports only 11g and 12c connections')
|
||||
)
|
||||
port = serializers.IntegerField(default=1521, label=_('Port'), allow_null=True)
|
||||
|
|
|
@ -49,6 +49,9 @@ class DBPortManager(object):
|
|||
for port, db_id in mapper.items():
|
||||
if db_id == str(db.id):
|
||||
return port
|
||||
logger.warning(
|
||||
'Not matched db port, db_id: {}, mapper length: {}'.format(db.id, len(mapper))
|
||||
)
|
||||
|
||||
def get_db_by_port(self, port):
|
||||
mapper = self.get_mapper()
|
||||
|
|
|
@ -62,12 +62,15 @@ class ConnectionTokenMixin:
|
|||
|
||||
def get_smart_endpoint(self, protocol, asset=None, application=None):
|
||||
if asset:
|
||||
target_instance = asset
|
||||
target_ip = asset.get_target_ip()
|
||||
elif application:
|
||||
target_instance = application
|
||||
target_ip = application.get_target_ip()
|
||||
else:
|
||||
target_instance = None
|
||||
target_ip = ''
|
||||
endpoint = EndpointRule.match_endpoint(target_ip, protocol, self.request)
|
||||
endpoint = EndpointRule.match_endpoint(target_instance, target_ip, protocol, self.request)
|
||||
return endpoint
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -47,11 +47,12 @@ class SmartEndpointViewMixin:
|
|||
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
|
||||
|
||||
def match_endpoint_by_target_ip(self):
|
||||
# 用来方便测试
|
||||
target_ip = self.request.GET.get('target_ip', '')
|
||||
target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数,用来方便测试
|
||||
if not target_ip and callable(getattr(self.target_instance, 'get_target_ip', None)):
|
||||
target_ip = self.target_instance.get_target_ip()
|
||||
endpoint = EndpointRule.match_endpoint(target_ip, self.target_protocol, self.request)
|
||||
endpoint = EndpointRule.match_endpoint(
|
||||
self.target_instance, target_ip, self.target_protocol, self.request
|
||||
)
|
||||
return endpoint
|
||||
|
||||
def get_target_instance(self):
|
||||
|
@ -83,12 +84,7 @@ class SmartEndpointViewMixin:
|
|||
return instance
|
||||
|
||||
def get_target_protocol(self):
|
||||
protocol = None
|
||||
if isinstance(self.target_instance, Application) and self.target_instance.is_type(Application.APP_TYPE.oracle):
|
||||
protocol = self.target_instance.get_target_protocol_for_oracle()
|
||||
if not protocol:
|
||||
protocol = self.request.GET.get('protocol')
|
||||
return protocol
|
||||
return self.request.GET.get('protocol')
|
||||
|
||||
|
||||
class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet):
|
||||
|
|
|
@ -1,27 +1,23 @@
|
|||
from django.db import models
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from django.core.validators import MinValueValidator, MaxValueValidator
|
||||
from applications.models import Application
|
||||
from applications.utils import db_port_manager
|
||||
from common.db.models import JMSModel
|
||||
from common.db.fields import PortField
|
||||
from common.utils.ip import contains_ip
|
||||
from common.exceptions import JMSException
|
||||
|
||||
|
||||
class Endpoint(JMSModel):
|
||||
name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
|
||||
host = models.CharField(max_length=256, blank=True, verbose_name=_('Host'))
|
||||
# disabled value=0
|
||||
# value=0 表示 disabled
|
||||
https_port = PortField(default=443, verbose_name=_('HTTPS Port'))
|
||||
http_port = PortField(default=80, verbose_name=_('HTTP Port'))
|
||||
ssh_port = PortField(default=2222, verbose_name=_('SSH Port'))
|
||||
rdp_port = PortField(default=3389, verbose_name=_('RDP Port'))
|
||||
|
||||
# Todo: Delete
|
||||
mysql_port = PortField(default=33060, verbose_name=_('MySQL Port'))
|
||||
mariadb_port = PortField(default=33061, verbose_name=_('MariaDB Port'))
|
||||
postgresql_port = PortField(default=54320, verbose_name=_('PostgreSQL Port'))
|
||||
redis_port = PortField(default=63790, verbose_name=_('Redis Port'))
|
||||
oracle_11g_port = PortField(default=15211, verbose_name=_('Oracle 11g Port'))
|
||||
oracle_12c_port = PortField(default=15212, verbose_name=_('Oracle 12c Port'))
|
||||
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
|
||||
|
||||
default_id = '00000000-0000-0000-0000-000000000001'
|
||||
|
@ -33,12 +29,18 @@ class Endpoint(JMSModel):
|
|||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def get_port(self, protocol):
|
||||
return getattr(self, f'{protocol}_port', 0)
|
||||
|
||||
def get_oracle_port(self, version):
|
||||
protocol = f'oracle_{version}'
|
||||
return self.get_port(protocol)
|
||||
def get_port(self, target_instance, protocol):
|
||||
if protocol in ['https', 'http', 'ssh', 'rdp']:
|
||||
port = getattr(self, f'{protocol}_port', 0)
|
||||
elif isinstance(target_instance, Application) and target_instance.category_db:
|
||||
port = db_port_manager.get_port_by_db(target_instance)
|
||||
if port is None:
|
||||
error = 'No application port is matched, application id: {}' \
|
||||
''.format(target_instance.id)
|
||||
raise JMSException(error)
|
||||
else:
|
||||
port = 0
|
||||
return port
|
||||
|
||||
def is_default(self):
|
||||
return str(self.id) == self.default_id
|
||||
|
@ -48,10 +50,10 @@ class Endpoint(JMSModel):
|
|||
return
|
||||
return super().delete(using, keep_parents)
|
||||
|
||||
def is_valid_for(self, protocol):
|
||||
def is_valid_for(self, target_instance, protocol):
|
||||
if self.is_default():
|
||||
return True
|
||||
if self.host and self.get_port(protocol) != 0:
|
||||
if self.host and self.get_port(target_instance, protocol) != 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -105,19 +107,19 @@ class EndpointRule(JMSModel):
|
|||
return f'{self.name}({self.priority})'
|
||||
|
||||
@classmethod
|
||||
def match(cls, target_ip, protocol):
|
||||
def match(cls, target_instance, target_ip, protocol):
|
||||
for endpoint_rule in cls.objects.all().prefetch_related('endpoint'):
|
||||
if not contains_ip(target_ip, endpoint_rule.ip_group):
|
||||
continue
|
||||
if not endpoint_rule.endpoint:
|
||||
continue
|
||||
if not endpoint_rule.endpoint.is_valid_for(protocol):
|
||||
if not endpoint_rule.endpoint.is_valid_for(target_instance, protocol):
|
||||
continue
|
||||
return endpoint_rule
|
||||
|
||||
@classmethod
|
||||
def match_endpoint(cls, target_ip, protocol, request=None):
|
||||
endpoint_rule = cls.match(target_ip, protocol)
|
||||
def match_endpoint(cls, target_instance, target_ip, protocol, request=None):
|
||||
endpoint_rule = cls.match(target_instance, target_ip, protocol)
|
||||
if endpoint_rule:
|
||||
endpoint = endpoint_rule.endpoint
|
||||
else:
|
||||
|
|
|
@ -2,25 +2,23 @@ from rest_framework import serializers
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
from common.drf.serializers import BulkModelSerializer
|
||||
from acls.serializers.rules import ip_group_child_validator, ip_group_help_text
|
||||
from django.conf import settings
|
||||
from ..models import Endpoint, EndpointRule
|
||||
|
||||
__all__ = ['EndpointSerializer', 'EndpointRuleSerializer']
|
||||
|
||||
|
||||
class EndpointSerializer(BulkModelSerializer):
|
||||
# 解决 luna 处理繁琐的问题,oracle_port 返回匹配到的端口
|
||||
oracle_port = serializers.SerializerMethodField(label=_('Oracle port'))
|
||||
# 解决 luna 处理繁琐的问题, 返回 magnus 监听的当前 db 的 port
|
||||
magnus_listen_db_port = serializers.SerializerMethodField(label=_('Magnus listen db port'))
|
||||
magnus_listen_port_range = serializers.SerializerMethodField(label=_('Magnus Listen port range'))
|
||||
|
||||
class Meta:
|
||||
model = Endpoint
|
||||
fields_mini = ['id', 'name']
|
||||
fields_small = [
|
||||
'host',
|
||||
'https_port', 'http_port', 'ssh_port',
|
||||
'rdp_port', 'mysql_port', 'mariadb_port',
|
||||
'postgresql_port', 'redis_port',
|
||||
'oracle_11g_port', 'oracle_12c_port',
|
||||
'oracle_port',
|
||||
'https_port', 'http_port', 'ssh_port', 'rdp_port',
|
||||
]
|
||||
fields = fields_mini + fields_small + [
|
||||
'comment', 'date_created', 'date_updated', 'created_by'
|
||||
|
@ -30,19 +28,20 @@ class EndpointSerializer(BulkModelSerializer):
|
|||
'http_port': {'default': 80},
|
||||
'ssh_port': {'default': 2222},
|
||||
'rdp_port': {'default': 3389},
|
||||
'mysql_port': {'default': 33060},
|
||||
'mariadb_port': {'default': 33061},
|
||||
'postgresql_port': {'default': 54320},
|
||||
'redis_port': {'default': 63790},
|
||||
'oracle_11g_port': {'default': 15211},
|
||||
'oracle_12c_port': {'default': 15212},
|
||||
}
|
||||
|
||||
def get_oracle_port(self, obj: Endpoint):
|
||||
def get_magnus_listen_db_port(self, obj: Endpoint):
|
||||
view = self.context.get('view')
|
||||
if not view or view.action not in ['smart']:
|
||||
return 0
|
||||
return obj.get_port(view.target_protocol)
|
||||
return obj.get_port(view.target_instance, view.target_protocol)
|
||||
|
||||
@staticmethod
|
||||
def get_magnus_listen_port_range(obj: Endpoint):
|
||||
port_start = settings.MAGNUS_DB_PORTS_START
|
||||
port_limit = settings.MAGNUS_DB_PORTS_LIMIT_COUNT
|
||||
port_end = port_start + port_limit + 1
|
||||
return f'{port_start} - {port_end}'
|
||||
|
||||
|
||||
class EndpointRuleSerializer(BulkModelSerializer):
|
||||
|
|
Loading…
Reference in New Issue