mirror of https://github.com/jumpserver/jumpserver
				
				
				
			feat: 修改 Endpoint 获取 Manugs DB listen port 的逻辑
							parent
							
								
									b8ec60dea1
								
							
						
					
					
						commit
						57e12256e7
					
				| 
						 | 
				
			
			@ -83,9 +83,3 @@ class AppType(models.TextChoices):
 | 
			
		|||
        if AppCategory.is_xpack(category):
 | 
			
		||||
            return True
 | 
			
		||||
        return tp in ['oracle', 'postgresql', 'sqlserver']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OracleVersion(models.TextChoices):
 | 
			
		||||
    version_11g = '11g', '11g'
 | 
			
		||||
    version_12c = '12c', '12c'
 | 
			
		||||
    version_other = 'other', _('Other')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,7 +10,6 @@ from common.mixins import CommonModelMixin
 | 
			
		|||
from common.tree import TreeNode
 | 
			
		||||
from common.utils import is_uuid
 | 
			
		||||
from assets.models import Asset, SystemUser
 | 
			
		||||
from ..const import OracleVersion
 | 
			
		||||
 | 
			
		||||
from .. import const
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -304,15 +303,6 @@ class Application(CommonModelMixin, OrgModelMixin, ApplicationTreeNodeMixin):
 | 
			
		|||
            target_ip = self.attrs.get('host')
 | 
			
		||||
        return target_ip
 | 
			
		||||
 | 
			
		||||
    def get_target_protocol_for_oracle(self):
 | 
			
		||||
        """ Oracle 类型需要单独处理,因为要携带版本号 """
 | 
			
		||||
        if not self.is_type(self.APP_TYPE.oracle):
 | 
			
		||||
            return
 | 
			
		||||
        version = self.attrs.get('version', OracleVersion.version_12c)
 | 
			
		||||
        if version == OracleVersion.version_other:
 | 
			
		||||
            return
 | 
			
		||||
        return 'oracle_%s' % version
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ApplicationUser(SystemUser):
 | 
			
		||||
    class Meta:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,15 +2,9 @@ from rest_framework import serializers
 | 
			
		|||
from django.utils.translation import ugettext_lazy as _
 | 
			
		||||
 | 
			
		||||
from ..application_category import DBSerializer
 | 
			
		||||
from applications.const import OracleVersion
 | 
			
		||||
 | 
			
		||||
__all__ = ['OracleSerializer']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OracleSerializer(DBSerializer):
 | 
			
		||||
    version = serializers.ChoiceField(
 | 
			
		||||
        choices=OracleVersion.choices, default=OracleVersion.version_12c,
 | 
			
		||||
        allow_null=True, label=_('Version'),
 | 
			
		||||
        help_text=_('Magnus currently supports only 11g and 12c connections')
 | 
			
		||||
    )
 | 
			
		||||
    port = serializers.IntegerField(default=1521, label=_('Port'), allow_null=True)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -49,6 +49,9 @@ class DBPortManager(object):
 | 
			
		|||
        for port, db_id in mapper.items():
 | 
			
		||||
            if db_id == str(db.id):
 | 
			
		||||
                return port
 | 
			
		||||
        logger.warning(
 | 
			
		||||
            'Not matched db port, db_id: {}, mapper length: {}'.format(db.id, len(mapper))
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_db_by_port(self, port):
 | 
			
		||||
        mapper = self.get_mapper()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -62,12 +62,15 @@ class ConnectionTokenMixin:
 | 
			
		|||
 | 
			
		||||
    def get_smart_endpoint(self, protocol, asset=None, application=None):
 | 
			
		||||
        if asset:
 | 
			
		||||
            target_instance = asset
 | 
			
		||||
            target_ip = asset.get_target_ip()
 | 
			
		||||
        elif application:
 | 
			
		||||
            target_instance = application
 | 
			
		||||
            target_ip = application.get_target_ip()
 | 
			
		||||
        else:
 | 
			
		||||
            target_instance = None
 | 
			
		||||
            target_ip = ''
 | 
			
		||||
        endpoint = EndpointRule.match_endpoint(target_ip, protocol, self.request)
 | 
			
		||||
        endpoint = EndpointRule.match_endpoint(target_instance, target_ip, protocol, self.request)
 | 
			
		||||
        return endpoint
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -47,11 +47,12 @@ class SmartEndpointViewMixin:
 | 
			
		|||
        return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
 | 
			
		||||
 | 
			
		||||
    def match_endpoint_by_target_ip(self):
 | 
			
		||||
        # 用来方便测试
 | 
			
		||||
        target_ip = self.request.GET.get('target_ip', '')
 | 
			
		||||
        target_ip = self.request.GET.get('target_ip', '')  # 支持target_ip参数,用来方便测试
 | 
			
		||||
        if not target_ip and callable(getattr(self.target_instance, 'get_target_ip', None)):
 | 
			
		||||
            target_ip = self.target_instance.get_target_ip()
 | 
			
		||||
        endpoint = EndpointRule.match_endpoint(target_ip, self.target_protocol, self.request)
 | 
			
		||||
        endpoint = EndpointRule.match_endpoint(
 | 
			
		||||
            self.target_instance, target_ip, self.target_protocol, self.request
 | 
			
		||||
        )
 | 
			
		||||
        return endpoint
 | 
			
		||||
 | 
			
		||||
    def get_target_instance(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -83,12 +84,7 @@ class SmartEndpointViewMixin:
 | 
			
		|||
            return instance
 | 
			
		||||
 | 
			
		||||
    def get_target_protocol(self):
 | 
			
		||||
        protocol = None
 | 
			
		||||
        if isinstance(self.target_instance, Application) and self.target_instance.is_type(Application.APP_TYPE.oracle):
 | 
			
		||||
            protocol = self.target_instance.get_target_protocol_for_oracle()
 | 
			
		||||
        if not protocol:
 | 
			
		||||
            protocol = self.request.GET.get('protocol')
 | 
			
		||||
        return protocol
 | 
			
		||||
        return self.request.GET.get('protocol')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,27 +1,23 @@
 | 
			
		|||
from django.db import models
 | 
			
		||||
from django.utils.translation import ugettext_lazy as _
 | 
			
		||||
from django.core.validators import MinValueValidator, MaxValueValidator
 | 
			
		||||
from applications.models import Application
 | 
			
		||||
from applications.utils import db_port_manager
 | 
			
		||||
from common.db.models import JMSModel
 | 
			
		||||
from common.db.fields import PortField
 | 
			
		||||
from common.utils.ip import contains_ip
 | 
			
		||||
from common.exceptions import JMSException
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Endpoint(JMSModel):
 | 
			
		||||
    name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
 | 
			
		||||
    host = models.CharField(max_length=256, blank=True, verbose_name=_('Host'))
 | 
			
		||||
    # disabled value=0
 | 
			
		||||
    # value=0 表示 disabled
 | 
			
		||||
    https_port = PortField(default=443, verbose_name=_('HTTPS Port'))
 | 
			
		||||
    http_port = PortField(default=80, verbose_name=_('HTTP Port'))
 | 
			
		||||
    ssh_port = PortField(default=2222, verbose_name=_('SSH Port'))
 | 
			
		||||
    rdp_port = PortField(default=3389, verbose_name=_('RDP Port'))
 | 
			
		||||
 | 
			
		||||
    # Todo: Delete
 | 
			
		||||
    mysql_port = PortField(default=33060, verbose_name=_('MySQL Port'))
 | 
			
		||||
    mariadb_port = PortField(default=33061, verbose_name=_('MariaDB Port'))
 | 
			
		||||
    postgresql_port = PortField(default=54320, verbose_name=_('PostgreSQL Port'))
 | 
			
		||||
    redis_port = PortField(default=63790, verbose_name=_('Redis Port'))
 | 
			
		||||
    oracle_11g_port = PortField(default=15211, verbose_name=_('Oracle 11g Port'))
 | 
			
		||||
    oracle_12c_port = PortField(default=15212, verbose_name=_('Oracle 12c Port'))
 | 
			
		||||
    comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
 | 
			
		||||
 | 
			
		||||
    default_id = '00000000-0000-0000-0000-000000000001'
 | 
			
		||||
| 
						 | 
				
			
			@ -33,12 +29,18 @@ class Endpoint(JMSModel):
 | 
			
		|||
    def __str__(self):
 | 
			
		||||
        return self.name
 | 
			
		||||
 | 
			
		||||
    def get_port(self, protocol):
 | 
			
		||||
        return getattr(self, f'{protocol}_port', 0)
 | 
			
		||||
 | 
			
		||||
    def get_oracle_port(self, version):
 | 
			
		||||
        protocol = f'oracle_{version}'
 | 
			
		||||
        return self.get_port(protocol)
 | 
			
		||||
    def get_port(self, target_instance, protocol):
 | 
			
		||||
        if protocol in ['https', 'http', 'ssh', 'rdp']:
 | 
			
		||||
            port = getattr(self, f'{protocol}_port', 0)
 | 
			
		||||
        elif isinstance(target_instance, Application) and target_instance.category_db:
 | 
			
		||||
            port = db_port_manager.get_port_by_db(target_instance)
 | 
			
		||||
            if port is None:
 | 
			
		||||
                error = 'No application port is matched, application id: {}' \
 | 
			
		||||
                        ''.format(target_instance.id)
 | 
			
		||||
                raise JMSException(error)
 | 
			
		||||
        else:
 | 
			
		||||
            port = 0
 | 
			
		||||
        return port
 | 
			
		||||
 | 
			
		||||
    def is_default(self):
 | 
			
		||||
        return str(self.id) == self.default_id
 | 
			
		||||
| 
						 | 
				
			
			@ -48,10 +50,10 @@ class Endpoint(JMSModel):
 | 
			
		|||
            return
 | 
			
		||||
        return super().delete(using, keep_parents)
 | 
			
		||||
 | 
			
		||||
    def is_valid_for(self, protocol):
 | 
			
		||||
    def is_valid_for(self, target_instance, protocol):
 | 
			
		||||
        if self.is_default():
 | 
			
		||||
            return True
 | 
			
		||||
        if self.host and self.get_port(protocol) != 0:
 | 
			
		||||
        if self.host and self.get_port(target_instance, protocol) != 0:
 | 
			
		||||
            return True
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -105,19 +107,19 @@ class EndpointRule(JMSModel):
 | 
			
		|||
        return f'{self.name}({self.priority})'
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def match(cls, target_ip, protocol):
 | 
			
		||||
    def match(cls, target_instance, target_ip, protocol):
 | 
			
		||||
        for endpoint_rule in cls.objects.all().prefetch_related('endpoint'):
 | 
			
		||||
            if not contains_ip(target_ip, endpoint_rule.ip_group):
 | 
			
		||||
                continue
 | 
			
		||||
            if not endpoint_rule.endpoint:
 | 
			
		||||
                continue
 | 
			
		||||
            if not endpoint_rule.endpoint.is_valid_for(protocol):
 | 
			
		||||
            if not endpoint_rule.endpoint.is_valid_for(target_instance, protocol):
 | 
			
		||||
                continue
 | 
			
		||||
            return endpoint_rule
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def match_endpoint(cls, target_ip, protocol, request=None):
 | 
			
		||||
        endpoint_rule = cls.match(target_ip, protocol)
 | 
			
		||||
    def match_endpoint(cls, target_instance, target_ip, protocol, request=None):
 | 
			
		||||
        endpoint_rule = cls.match(target_instance, target_ip, protocol)
 | 
			
		||||
        if endpoint_rule:
 | 
			
		||||
            endpoint = endpoint_rule.endpoint
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,25 +2,23 @@ from rest_framework import serializers
 | 
			
		|||
from django.utils.translation import ugettext_lazy as _
 | 
			
		||||
from common.drf.serializers import BulkModelSerializer
 | 
			
		||||
from acls.serializers.rules import ip_group_child_validator, ip_group_help_text
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from ..models import Endpoint, EndpointRule
 | 
			
		||||
 | 
			
		||||
__all__ = ['EndpointSerializer', 'EndpointRuleSerializer']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EndpointSerializer(BulkModelSerializer):
 | 
			
		||||
    # 解决 luna 处理繁琐的问题,oracle_port 返回匹配到的端口
 | 
			
		||||
    oracle_port = serializers.SerializerMethodField(label=_('Oracle port'))
 | 
			
		||||
    # 解决 luna 处理繁琐的问题, 返回 magnus 监听的当前 db 的 port
 | 
			
		||||
    magnus_listen_db_port = serializers.SerializerMethodField(label=_('Magnus listen db port'))
 | 
			
		||||
    magnus_listen_port_range = serializers.SerializerMethodField(label=_('Magnus Listen port range'))
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Endpoint
 | 
			
		||||
        fields_mini = ['id', 'name']
 | 
			
		||||
        fields_small = [
 | 
			
		||||
            'host',
 | 
			
		||||
            'https_port', 'http_port', 'ssh_port',
 | 
			
		||||
            'rdp_port', 'mysql_port', 'mariadb_port',
 | 
			
		||||
            'postgresql_port', 'redis_port',
 | 
			
		||||
            'oracle_11g_port', 'oracle_12c_port',
 | 
			
		||||
            'oracle_port',
 | 
			
		||||
            'https_port', 'http_port', 'ssh_port', 'rdp_port',
 | 
			
		||||
        ]
 | 
			
		||||
        fields = fields_mini + fields_small + [
 | 
			
		||||
            'comment', 'date_created', 'date_updated', 'created_by'
 | 
			
		||||
| 
						 | 
				
			
			@ -30,19 +28,20 @@ class EndpointSerializer(BulkModelSerializer):
 | 
			
		|||
            'http_port': {'default': 80},
 | 
			
		||||
            'ssh_port': {'default': 2222},
 | 
			
		||||
            'rdp_port': {'default': 3389},
 | 
			
		||||
            'mysql_port': {'default': 33060},
 | 
			
		||||
            'mariadb_port': {'default': 33061},
 | 
			
		||||
            'postgresql_port': {'default': 54320},
 | 
			
		||||
            'redis_port': {'default': 63790},
 | 
			
		||||
            'oracle_11g_port': {'default': 15211},
 | 
			
		||||
            'oracle_12c_port': {'default': 15212},
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def get_oracle_port(self, obj: Endpoint):
 | 
			
		||||
    def get_magnus_listen_db_port(self, obj: Endpoint):
 | 
			
		||||
        view = self.context.get('view')
 | 
			
		||||
        if not view or view.action not in ['smart']:
 | 
			
		||||
            return 0
 | 
			
		||||
        return obj.get_port(view.target_protocol)
 | 
			
		||||
        return obj.get_port(view.target_instance, view.target_protocol)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_magnus_listen_port_range(obj: Endpoint):
 | 
			
		||||
        port_start = settings.MAGNUS_DB_PORTS_START
 | 
			
		||||
        port_limit = settings.MAGNUS_DB_PORTS_LIMIT_COUNT
 | 
			
		||||
        port_end = port_start + port_limit + 1
 | 
			
		||||
        return f'{port_start} - {port_end}'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EndpointRuleSerializer(BulkModelSerializer):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue