mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			153 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			153 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Python
		
	
	
| from django.core.validators import MinValueValidator, MaxValueValidator
 | |
| from django.db import models
 | |
| from django.db.models import Prefetch
 | |
| from django.utils.translation import gettext_lazy as _
 | |
| 
 | |
| from assets.models import Asset
 | |
| from common.db.fields import PortField
 | |
| from common.db.models import JMSBaseModel
 | |
| from common.utils.ip import contains_ip
 | |
| 
 | |
| 
 | |
class Endpoint(JMSBaseModel):
    """Address (host plus per-protocol ports) that clients connect through.

    Each ``<protocol>_port`` field holds the port exposed for that protocol.
    A port value of 0 means the protocol is disabled on this endpoint.
    The row whose primary key equals ``default_id`` is the built-in default
    endpoint: it is always considered valid and cannot be deleted.
    """
    name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
    host = models.CharField(max_length=256, blank=True, verbose_name=_('Host'))
    # value=0 means disabled
    https_port = PortField(default=443, verbose_name=_('HTTPS port'))
    http_port = PortField(default=80, verbose_name=_('HTTP port'))
    ssh_port = PortField(default=2222, verbose_name=_('SSH port'))
    rdp_port = PortField(default=3389, verbose_name=_('RDP port'))
    mysql_port = PortField(default=33061, verbose_name=_('MySQL port'))
    mariadb_port = PortField(default=33062, verbose_name=_('MariaDB port'))
    postgresql_port = PortField(default=54320, verbose_name=_('PostgreSQL port'))
    redis_port = PortField(default=63790, verbose_name=_('Redis port'))
    sqlserver_port = PortField(default=14330, verbose_name=_('SQLServer port'))
    vnc_port = PortField(default=15900, verbose_name=_('VNC port'))

    comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
    is_active = models.BooleanField(default=True, verbose_name=_('Active'))

    # Fixed primary key of the built-in default endpoint.
    default_id = '00000000-0000-0000-0000-000000000001'

    class Meta:
        verbose_name = _('Endpoint')
        ordering = ('name',)

    def __str__(self):
        return self.name

    def get_port(self, target_instance, protocol):
        """Return the port this endpoint exposes for ``protocol``.

        Oracle assets accessed over the oracle protocol get their port from
        ``db_port_manager``. For every other case the value of the
        ``<protocol>_port`` attribute is returned, with sftp and telnet
        sharing the ssh port. Returns 0 when no such attribute exists.
        """
        # Imported lazily, as in the original code (presumably to avoid
        # circular imports between apps).
        from terminal.utils import db_port_manager
        from assets.const import DatabaseTypes, Protocol

        is_oracle_asset = (
            isinstance(target_instance, Asset)
            and target_instance.is_type(DatabaseTypes.ORACLE)
            and protocol == Protocol.oracle
        )
        if is_oracle_asset:
            return db_port_manager.get_port_by_db(target_instance)

        if protocol in [Protocol.sftp, Protocol.telnet]:
            protocol = Protocol.ssh
        return getattr(self, f'{protocol}_port', 0)

    def is_default(self):
        """Return True when this row is the built-in default endpoint."""
        return str(self.id) == self.default_id

    def delete(self, using=None, keep_parents=False):
        """Delete the endpoint, silently refusing to delete the default one.

        Returns the usual Django ``(deleted_count, per_type_counts)`` tuple.
        For the protected default endpoint nothing is deleted and
        ``(0, {})`` is returned, so callers that unpack the result keep
        working.
        """
        if self.is_default():
            return 0, {}
        return super().delete(using=using, keep_parents=keep_parents)

    def is_valid_for(self, target_instance, protocol):
        """Return True if this endpoint can serve ``protocol`` for the target.

        The default endpoint is always valid; any other endpoint is valid
        when its port for the protocol is not 0 (0 means disabled).
        """
        return self.is_default() or self.get_port(target_instance, protocol) != 0

    @classmethod
    def get_or_create_default(cls, request=None):
        """Fetch the default endpoint, creating it on first use.

        ``request`` is accepted for interface compatibility and is not used
        here.
        """
        defaults = {
            'id': cls.default_id,
            'name': 'Default',
            'host': '',
            'https_port': 0,
            'http_port': 0,
        }
        endpoint, _created = cls.objects.get_or_create(id=cls.default_id, defaults=defaults)
        return endpoint

    @classmethod
    def handle_endpoint_host(cls, endpoint, request=None):
        """Fill an empty ``endpoint.host`` from the current request's host.

        The endpoint is modified in memory only (it is not saved) and is
        returned. Nothing changes when the endpoint already has a host or
        no request is given.
        """
        if endpoint.host or not request:
            return endpoint

        # Dynamically use the host of the current request.
        host_port = request.get_host()
        if host_port.startswith('['):
            # IPv6 literal such as "[::1]:8080": keep the brackets, drop the port.
            host = host_port.split(']:')[0].rstrip(']') + ']'
        else:
            host = host_port.split(':')[0]
        endpoint.host = host
        return endpoint

    @classmethod
    def match_by_instance_label(cls, instance, protocol, request=None):
        """Pick an endpoint named by the asset's ``endpoint`` labels.

        ``instance`` may be an Asset or a Session (resolved to its asset).
        Among active endpoints whose name equals one of the label values,
        the most recently updated one that is valid for ``protocol`` is
        returned, with its host filled in from ``request`` if empty.
        Returns None when the instance is not an asset, has no such label,
        or no endpoint qualifies.
        """
        # Imported lazily, as in the original code (presumably to avoid a
        # circular import with terminal.models).
        from terminal.models import Session

        if isinstance(instance, Session):
            instance = instance.get_asset()
        if not isinstance(instance, Asset):
            return None

        names = list(
            instance.labels.filter(label__name='endpoint').values_list('label__value', flat=True)
        )
        if not names:
            return None

        endpoints = cls.objects.filter(is_active=True, name__in=names).order_by('-date_updated')
        for endpoint in endpoints:
            if endpoint.is_valid_for(instance, protocol):
                return cls.handle_endpoint_host(endpoint, request)
        return None
 | |
| 
 | |
| 
 | |
class EndpointRule(JMSBaseModel):
    """Rule that maps a group of IP addresses to an ``Endpoint``.

    ``match`` returns the first active rule whose ``ip_group`` contains the
    target IP and whose endpoint is usable for the requested protocol.
    """
    name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
    ip_group = models.JSONField(default=list, verbose_name=_('IP group'))
    priority = models.IntegerField(
        verbose_name=_("Priority"),
        validators=[MinValueValidator(1), MaxValueValidator(100)],
        unique=True,
        help_text=_("1-100, the lower the value will be match first"),
    )
    endpoint = models.ForeignKey(
        'terminal.Endpoint',
        null=True,
        blank=True,
        related_name='rules',
        on_delete=models.SET_NULL,
        verbose_name=_("Endpoint"),
    )
    comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
    is_active = models.BooleanField(default=True, verbose_name=_('Active'))

    class Meta:
        verbose_name = _('Endpoint rule')
        ordering = ('priority', 'is_active', 'name')

    def __str__(self):
        return f'{self.name}({self.priority})'

    @classmethod
    def match(cls, target_instance, target_ip, protocol):
        """Return the first active rule applicable to the target, or None.

        A rule applies when its ``ip_group`` contains ``target_ip``, it has
        an endpoint, and that endpoint is valid for ``target_instance`` and
        ``protocol``. Rules are visited in the queryset's order
        (``Meta.ordering`` lists priority first; assumes the manager does
        not override it).
        """
        # Only active endpoints are prefetched.
        # NOTE(review): a rule pointing at an inactive endpoint is expected
        # to come back with ``endpoint`` set to None - confirm against
        # Django's prefetch behaviour for forward foreign keys.
        endpoint_prefetch = Prefetch('endpoint', queryset=Endpoint.objects.filter(is_active=True))
        candidates = cls.objects.prefetch_related(endpoint_prefetch).filter(is_active=True)
        for rule in candidates:
            if (
                contains_ip(target_ip, rule.ip_group)
                and rule.endpoint
                and rule.endpoint.is_valid_for(target_instance, protocol)
            ):
                return rule
        return None

    @classmethod
    def match_endpoint(cls, target_instance, target_ip, protocol, request=None):
        """Resolve the endpoint to use for the target.

        Uses the endpoint of the matching rule, falling back to the default
        endpoint when no rule matches. An empty host is then filled in from
        ``request`` before the endpoint is returned.
        """
        rule = cls.match(target_instance, target_ip, protocol)
        chosen = rule.endpoint if rule else Endpoint.get_or_create_default(request)
        return Endpoint.handle_endpoint_host(chosen, request)
 |