mirror of https://github.com/jumpserver/jumpserver
feat: 修改 Endpoint 获取 Manugs DB listen port 的逻辑
parent
b8ec60dea1
commit
57e12256e7
|
@ -83,9 +83,3 @@ class AppType(models.TextChoices):
|
||||||
if AppCategory.is_xpack(category):
|
if AppCategory.is_xpack(category):
|
||||||
return True
|
return True
|
||||||
return tp in ['oracle', 'postgresql', 'sqlserver']
|
return tp in ['oracle', 'postgresql', 'sqlserver']
|
||||||
|
|
||||||
|
|
||||||
class OracleVersion(models.TextChoices):
|
|
||||||
version_11g = '11g', '11g'
|
|
||||||
version_12c = '12c', '12c'
|
|
||||||
version_other = 'other', _('Other')
|
|
||||||
|
|
|
@ -10,7 +10,6 @@ from common.mixins import CommonModelMixin
|
||||||
from common.tree import TreeNode
|
from common.tree import TreeNode
|
||||||
from common.utils import is_uuid
|
from common.utils import is_uuid
|
||||||
from assets.models import Asset, SystemUser
|
from assets.models import Asset, SystemUser
|
||||||
from ..const import OracleVersion
|
|
||||||
|
|
||||||
from .. import const
|
from .. import const
|
||||||
|
|
||||||
|
@ -304,15 +303,6 @@ class Application(CommonModelMixin, OrgModelMixin, ApplicationTreeNodeMixin):
|
||||||
target_ip = self.attrs.get('host')
|
target_ip = self.attrs.get('host')
|
||||||
return target_ip
|
return target_ip
|
||||||
|
|
||||||
def get_target_protocol_for_oracle(self):
|
|
||||||
""" Oracle 类型需要单独处理,因为要携带版本号 """
|
|
||||||
if not self.is_type(self.APP_TYPE.oracle):
|
|
||||||
return
|
|
||||||
version = self.attrs.get('version', OracleVersion.version_12c)
|
|
||||||
if version == OracleVersion.version_other:
|
|
||||||
return
|
|
||||||
return 'oracle_%s' % version
|
|
||||||
|
|
||||||
|
|
||||||
class ApplicationUser(SystemUser):
|
class ApplicationUser(SystemUser):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
|
@ -2,15 +2,9 @@ from rest_framework import serializers
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
from ..application_category import DBSerializer
|
from ..application_category import DBSerializer
|
||||||
from applications.const import OracleVersion
|
|
||||||
|
|
||||||
__all__ = ['OracleSerializer']
|
__all__ = ['OracleSerializer']
|
||||||
|
|
||||||
|
|
||||||
class OracleSerializer(DBSerializer):
|
class OracleSerializer(DBSerializer):
|
||||||
version = serializers.ChoiceField(
|
|
||||||
choices=OracleVersion.choices, default=OracleVersion.version_12c,
|
|
||||||
allow_null=True, label=_('Version'),
|
|
||||||
help_text=_('Magnus currently supports only 11g and 12c connections')
|
|
||||||
)
|
|
||||||
port = serializers.IntegerField(default=1521, label=_('Port'), allow_null=True)
|
port = serializers.IntegerField(default=1521, label=_('Port'), allow_null=True)
|
||||||
|
|
|
@ -49,6 +49,9 @@ class DBPortManager(object):
|
||||||
for port, db_id in mapper.items():
|
for port, db_id in mapper.items():
|
||||||
if db_id == str(db.id):
|
if db_id == str(db.id):
|
||||||
return port
|
return port
|
||||||
|
logger.warning(
|
||||||
|
'Not matched db port, db_id: {}, mapper length: {}'.format(db.id, len(mapper))
|
||||||
|
)
|
||||||
|
|
||||||
def get_db_by_port(self, port):
|
def get_db_by_port(self, port):
|
||||||
mapper = self.get_mapper()
|
mapper = self.get_mapper()
|
||||||
|
|
|
@ -62,12 +62,15 @@ class ConnectionTokenMixin:
|
||||||
|
|
||||||
def get_smart_endpoint(self, protocol, asset=None, application=None):
|
def get_smart_endpoint(self, protocol, asset=None, application=None):
|
||||||
if asset:
|
if asset:
|
||||||
|
target_instance = asset
|
||||||
target_ip = asset.get_target_ip()
|
target_ip = asset.get_target_ip()
|
||||||
elif application:
|
elif application:
|
||||||
|
target_instance = application
|
||||||
target_ip = application.get_target_ip()
|
target_ip = application.get_target_ip()
|
||||||
else:
|
else:
|
||||||
|
target_instance = None
|
||||||
target_ip = ''
|
target_ip = ''
|
||||||
endpoint = EndpointRule.match_endpoint(target_ip, protocol, self.request)
|
endpoint = EndpointRule.match_endpoint(target_instance, target_ip, protocol, self.request)
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -47,11 +47,12 @@ class SmartEndpointViewMixin:
|
||||||
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
|
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
|
||||||
|
|
||||||
def match_endpoint_by_target_ip(self):
|
def match_endpoint_by_target_ip(self):
|
||||||
# 用来方便测试
|
target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数,用来方便测试
|
||||||
target_ip = self.request.GET.get('target_ip', '')
|
|
||||||
if not target_ip and callable(getattr(self.target_instance, 'get_target_ip', None)):
|
if not target_ip and callable(getattr(self.target_instance, 'get_target_ip', None)):
|
||||||
target_ip = self.target_instance.get_target_ip()
|
target_ip = self.target_instance.get_target_ip()
|
||||||
endpoint = EndpointRule.match_endpoint(target_ip, self.target_protocol, self.request)
|
endpoint = EndpointRule.match_endpoint(
|
||||||
|
self.target_instance, target_ip, self.target_protocol, self.request
|
||||||
|
)
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|
||||||
def get_target_instance(self):
|
def get_target_instance(self):
|
||||||
|
@ -83,12 +84,7 @@ class SmartEndpointViewMixin:
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def get_target_protocol(self):
|
def get_target_protocol(self):
|
||||||
protocol = None
|
return self.request.GET.get('protocol')
|
||||||
if isinstance(self.target_instance, Application) and self.target_instance.is_type(Application.APP_TYPE.oracle):
|
|
||||||
protocol = self.target_instance.get_target_protocol_for_oracle()
|
|
||||||
if not protocol:
|
|
||||||
protocol = self.request.GET.get('protocol')
|
|
||||||
return protocol
|
|
||||||
|
|
||||||
|
|
||||||
class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet):
|
class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet):
|
||||||
|
|
|
@ -1,27 +1,23 @@
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
from django.core.validators import MinValueValidator, MaxValueValidator
|
from django.core.validators import MinValueValidator, MaxValueValidator
|
||||||
|
from applications.models import Application
|
||||||
|
from applications.utils import db_port_manager
|
||||||
from common.db.models import JMSModel
|
from common.db.models import JMSModel
|
||||||
from common.db.fields import PortField
|
from common.db.fields import PortField
|
||||||
from common.utils.ip import contains_ip
|
from common.utils.ip import contains_ip
|
||||||
|
from common.exceptions import JMSException
|
||||||
|
|
||||||
|
|
||||||
class Endpoint(JMSModel):
|
class Endpoint(JMSModel):
|
||||||
name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
|
name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
|
||||||
host = models.CharField(max_length=256, blank=True, verbose_name=_('Host'))
|
host = models.CharField(max_length=256, blank=True, verbose_name=_('Host'))
|
||||||
# disabled value=0
|
# value=0 表示 disabled
|
||||||
https_port = PortField(default=443, verbose_name=_('HTTPS Port'))
|
https_port = PortField(default=443, verbose_name=_('HTTPS Port'))
|
||||||
http_port = PortField(default=80, verbose_name=_('HTTP Port'))
|
http_port = PortField(default=80, verbose_name=_('HTTP Port'))
|
||||||
ssh_port = PortField(default=2222, verbose_name=_('SSH Port'))
|
ssh_port = PortField(default=2222, verbose_name=_('SSH Port'))
|
||||||
rdp_port = PortField(default=3389, verbose_name=_('RDP Port'))
|
rdp_port = PortField(default=3389, verbose_name=_('RDP Port'))
|
||||||
|
|
||||||
# Todo: Delete
|
|
||||||
mysql_port = PortField(default=33060, verbose_name=_('MySQL Port'))
|
|
||||||
mariadb_port = PortField(default=33061, verbose_name=_('MariaDB Port'))
|
|
||||||
postgresql_port = PortField(default=54320, verbose_name=_('PostgreSQL Port'))
|
|
||||||
redis_port = PortField(default=63790, verbose_name=_('Redis Port'))
|
|
||||||
oracle_11g_port = PortField(default=15211, verbose_name=_('Oracle 11g Port'))
|
|
||||||
oracle_12c_port = PortField(default=15212, verbose_name=_('Oracle 12c Port'))
|
|
||||||
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
|
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
|
||||||
|
|
||||||
default_id = '00000000-0000-0000-0000-000000000001'
|
default_id = '00000000-0000-0000-0000-000000000001'
|
||||||
|
@ -33,12 +29,18 @@ class Endpoint(JMSModel):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def get_port(self, protocol):
|
def get_port(self, target_instance, protocol):
|
||||||
return getattr(self, f'{protocol}_port', 0)
|
if protocol in ['https', 'http', 'ssh', 'rdp']:
|
||||||
|
port = getattr(self, f'{protocol}_port', 0)
|
||||||
def get_oracle_port(self, version):
|
elif isinstance(target_instance, Application) and target_instance.category_db:
|
||||||
protocol = f'oracle_{version}'
|
port = db_port_manager.get_port_by_db(target_instance)
|
||||||
return self.get_port(protocol)
|
if port is None:
|
||||||
|
error = 'No application port is matched, application id: {}' \
|
||||||
|
''.format(target_instance.id)
|
||||||
|
raise JMSException(error)
|
||||||
|
else:
|
||||||
|
port = 0
|
||||||
|
return port
|
||||||
|
|
||||||
def is_default(self):
|
def is_default(self):
|
||||||
return str(self.id) == self.default_id
|
return str(self.id) == self.default_id
|
||||||
|
@ -48,10 +50,10 @@ class Endpoint(JMSModel):
|
||||||
return
|
return
|
||||||
return super().delete(using, keep_parents)
|
return super().delete(using, keep_parents)
|
||||||
|
|
||||||
def is_valid_for(self, protocol):
|
def is_valid_for(self, target_instance, protocol):
|
||||||
if self.is_default():
|
if self.is_default():
|
||||||
return True
|
return True
|
||||||
if self.host and self.get_port(protocol) != 0:
|
if self.host and self.get_port(target_instance, protocol) != 0:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -105,19 +107,19 @@ class EndpointRule(JMSModel):
|
||||||
return f'{self.name}({self.priority})'
|
return f'{self.name}({self.priority})'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def match(cls, target_ip, protocol):
|
def match(cls, target_instance, target_ip, protocol):
|
||||||
for endpoint_rule in cls.objects.all().prefetch_related('endpoint'):
|
for endpoint_rule in cls.objects.all().prefetch_related('endpoint'):
|
||||||
if not contains_ip(target_ip, endpoint_rule.ip_group):
|
if not contains_ip(target_ip, endpoint_rule.ip_group):
|
||||||
continue
|
continue
|
||||||
if not endpoint_rule.endpoint:
|
if not endpoint_rule.endpoint:
|
||||||
continue
|
continue
|
||||||
if not endpoint_rule.endpoint.is_valid_for(protocol):
|
if not endpoint_rule.endpoint.is_valid_for(target_instance, protocol):
|
||||||
continue
|
continue
|
||||||
return endpoint_rule
|
return endpoint_rule
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def match_endpoint(cls, target_ip, protocol, request=None):
|
def match_endpoint(cls, target_instance, target_ip, protocol, request=None):
|
||||||
endpoint_rule = cls.match(target_ip, protocol)
|
endpoint_rule = cls.match(target_instance, target_ip, protocol)
|
||||||
if endpoint_rule:
|
if endpoint_rule:
|
||||||
endpoint = endpoint_rule.endpoint
|
endpoint = endpoint_rule.endpoint
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -2,25 +2,23 @@ from rest_framework import serializers
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
from common.drf.serializers import BulkModelSerializer
|
from common.drf.serializers import BulkModelSerializer
|
||||||
from acls.serializers.rules import ip_group_child_validator, ip_group_help_text
|
from acls.serializers.rules import ip_group_child_validator, ip_group_help_text
|
||||||
|
from django.conf import settings
|
||||||
from ..models import Endpoint, EndpointRule
|
from ..models import Endpoint, EndpointRule
|
||||||
|
|
||||||
__all__ = ['EndpointSerializer', 'EndpointRuleSerializer']
|
__all__ = ['EndpointSerializer', 'EndpointRuleSerializer']
|
||||||
|
|
||||||
|
|
||||||
class EndpointSerializer(BulkModelSerializer):
|
class EndpointSerializer(BulkModelSerializer):
|
||||||
# 解决 luna 处理繁琐的问题,oracle_port 返回匹配到的端口
|
# 解决 luna 处理繁琐的问题, 返回 magnus 监听的当前 db 的 port
|
||||||
oracle_port = serializers.SerializerMethodField(label=_('Oracle port'))
|
magnus_listen_db_port = serializers.SerializerMethodField(label=_('Magnus listen db port'))
|
||||||
|
magnus_listen_port_range = serializers.SerializerMethodField(label=_('Magnus Listen port range'))
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Endpoint
|
model = Endpoint
|
||||||
fields_mini = ['id', 'name']
|
fields_mini = ['id', 'name']
|
||||||
fields_small = [
|
fields_small = [
|
||||||
'host',
|
'host',
|
||||||
'https_port', 'http_port', 'ssh_port',
|
'https_port', 'http_port', 'ssh_port', 'rdp_port',
|
||||||
'rdp_port', 'mysql_port', 'mariadb_port',
|
|
||||||
'postgresql_port', 'redis_port',
|
|
||||||
'oracle_11g_port', 'oracle_12c_port',
|
|
||||||
'oracle_port',
|
|
||||||
]
|
]
|
||||||
fields = fields_mini + fields_small + [
|
fields = fields_mini + fields_small + [
|
||||||
'comment', 'date_created', 'date_updated', 'created_by'
|
'comment', 'date_created', 'date_updated', 'created_by'
|
||||||
|
@ -30,19 +28,20 @@ class EndpointSerializer(BulkModelSerializer):
|
||||||
'http_port': {'default': 80},
|
'http_port': {'default': 80},
|
||||||
'ssh_port': {'default': 2222},
|
'ssh_port': {'default': 2222},
|
||||||
'rdp_port': {'default': 3389},
|
'rdp_port': {'default': 3389},
|
||||||
'mysql_port': {'default': 33060},
|
|
||||||
'mariadb_port': {'default': 33061},
|
|
||||||
'postgresql_port': {'default': 54320},
|
|
||||||
'redis_port': {'default': 63790},
|
|
||||||
'oracle_11g_port': {'default': 15211},
|
|
||||||
'oracle_12c_port': {'default': 15212},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_oracle_port(self, obj: Endpoint):
|
def get_magnus_listen_db_port(self, obj: Endpoint):
|
||||||
view = self.context.get('view')
|
view = self.context.get('view')
|
||||||
if not view or view.action not in ['smart']:
|
if not view or view.action not in ['smart']:
|
||||||
return 0
|
return 0
|
||||||
return obj.get_port(view.target_protocol)
|
return obj.get_port(view.target_instance, view.target_protocol)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_magnus_listen_port_range(obj: Endpoint):
|
||||||
|
port_start = settings.MAGNUS_DB_PORTS_START
|
||||||
|
port_limit = settings.MAGNUS_DB_PORTS_LIMIT_COUNT
|
||||||
|
port_end = port_start + port_limit + 1
|
||||||
|
return f'{port_start} - {port_end}'
|
||||||
|
|
||||||
|
|
||||||
class EndpointRuleSerializer(BulkModelSerializer):
|
class EndpointRuleSerializer(BulkModelSerializer):
|
||||||
|
|
Loading…
Reference in New Issue