feat: Endpoint 支持标签匹配

feat: Endpoint 支持标签匹配

feat: Endpoint 支持标签匹配

feat: Endpoint 支持标签匹配

feat: Endpoint 添加帮助信息

feat: Endpoint 添加帮助信息
pull/8511/head
Jiangjie.Bai 2022-06-28 20:12:55 +08:00 committed by Jiangjie.Bai
parent e8363ddff8
commit 05826abf9d
8 changed files with 584 additions and 771 deletions

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:132e7f59a56d1cf5b2358b21b547861e872fa456164f2e0809120fb2b13f0ec1
size 128122
oid sha256:5a2f54edd26cd86ec150e5380ba8f9ac3f05ef144d11bef16459ae82a3ec0583
size 125818

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:002f6953ebbe368642f0ea3c383f617b5f998edf2238341be63393123d4be8a9
size 105894
oid sha256:441fff7ae0f44b24707348210678fe723299bf5abba943ea7080ba4a789a0801
size 103826

File diff suppressed because it is too large Load Diff

View File

@ -16,19 +16,41 @@ from .. import serializers
__all__ = ['EndpointViewSet', 'EndpointRuleViewSet']
class EndpointViewSet(JMSBulkModelViewSet):
filterset_fields = ('name', 'host')
search_fields = filterset_fields
serializer_class = serializers.EndpointSerializer
queryset = Endpoint.objects.all()
class SmartEndpointViewMixin:
get_serializer: callable
@action(methods=['get'], detail=False, permission_classes=[IsValidUser], url_path='smart')
def smart(self, request, *args, **kwargs):
protocol = request.GET.get('protocol')
if not protocol:
error = _('Not found protocol query params')
return Response(data={'error': error}, status=status.HTTP_404_NOT_FOUND)
endpoint = self.match_endpoint(request, protocol)
serializer = self.get_serializer(endpoint)
return Response(serializer.data)
def match_endpoint(self, request, protocol):
instance = self.get_target_instance(request)
endpoint = self.match_endpoint_by_label(instance, protocol)
if not endpoint:
endpoint = self.match_endpoint_by_target_ip(request, instance, protocol)
return endpoint
@staticmethod
def get_target_ip(request):
# 用来方便测试
target_ip = request.GET.get('target_ip')
if target_ip:
return target_ip
def match_endpoint_by_label(instance, protocol):
return Endpoint.match_by_instance_label(instance, protocol)
@staticmethod
def match_endpoint_by_target_ip(request, instance, protocol):
# 用来方便测试
target_ip = request.GET.get('target_ip', '')
if not target_ip and callable(getattr(instance, 'get_target_ip', None)):
target_ip = instance.get_target_ip(request)
endpoint = EndpointRule.match_endpoint(target_ip, protocol, request)
return endpoint
@staticmethod
def get_target_instance(request):
asset_id = request.GET.get('asset_id')
app_id = request.GET.get('app_id')
session_id = request.GET.get('session_id')
@ -48,25 +70,19 @@ class EndpointViewSet(JMSBulkModelViewSet):
elif session_id:
pk, model = session_id, Session
else:
return ''
pk, model = None, None
if not pk or not model:
return None
with tmp_to_root_org():
instance = get_object_or_404(model, pk=pk)
target_ip = instance.get_target_ip()
return target_ip
return instance
@action(methods=['get'], detail=False, permission_classes=[IsValidUser], url_path='smart')
def smart(self, request, *args, **kwargs):
protocol = request.GET.get('protocol')
if not protocol:
return Response(
data={'error': _('Not found protocol query params')},
status=status.HTTP_404_NOT_FOUND
)
target_ip = self.get_target_ip(request)
endpoint = EndpointRule.match_endpoint(target_ip, protocol, request)
serializer = self.get_serializer(endpoint)
return Response(serializer.data)
class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet):
filterset_fields = ('name', 'host')
search_fields = filterset_fields
serializer_class = serializers.EndpointSerializer
queryset = Endpoint.objects.all()
class EndpointRuleViewSet(JMSBulkModelViewSet):

View File

@ -33,13 +33,20 @@ class Endpoint(JMSModel):
return getattr(self, f'{protocol}_port', 0)
def is_default(self):
return self.id == self.default_id
return str(self.id) == self.default_id
def delete(self, using=None, keep_parents=False):
if self.is_default():
return
return super().delete(using, keep_parents)
def is_valid_for(self, protocol):
if self.is_default():
return True
if self.host and self.get_port(protocol) != 0:
return True
return False
@classmethod
def get_or_create_default(cls, request=None):
data = {
@ -54,6 +61,22 @@ class Endpoint(JMSModel):
endpoint.host = request.get_host().split(':')[0]
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol):
from assets.models import Asset
from terminal.models import Session
if isinstance(instance, Session):
instance = instance.get_asset_or_application()
if not isinstance(instance, Asset):
return None
values = instance.labels.filter(name='endpoint').values_list('value', flat=True)
if not values:
return None
endpoints = cls.objects.filter(name__in=values).order_by('-date_updated')
for endpoint in endpoints:
if endpoint.is_valid_for(protocol):
return endpoint
class EndpointRule(JMSModel):
name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
@ -82,11 +105,7 @@ class EndpointRule(JMSModel):
continue
if not endpoint_rule.endpoint:
continue
if endpoint_rule.endpoint.is_default():
return endpoint_rule
if not endpoint_rule.endpoint.host:
continue
if endpoint_rule.endpoint.get_port(protocol) == 0:
if not endpoint_rule.endpoint.is_valid_for(protocol):
continue
return endpoint_rule

View File

@ -196,10 +196,14 @@ class Session(OrgModelMixin):
def login_from_display(self):
return self.get_login_from_display()
def get_target_ip(self):
def get_asset_or_application(self):
instance = get_object_or_none(Asset, pk=self.asset_id)
if not instance:
instance = get_object_or_none(Application, pk=self.asset_id)
return instance
def get_target_ip(self):
instance = self.get_asset_or_application()
target_ip = instance.get_target_ip() if instance else ''
return target_ip

View File

@ -1,7 +1,7 @@
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from common.drf.serializers import BulkModelSerializer
from acls.serializers.rules import ip_group_help_text, ip_group_child_validator
from acls.serializers.rules import ip_group_child_validator, ip_group_help_text
from ..models import Endpoint, EndpointRule
__all__ = ['EndpointSerializer', 'EndpointRuleSerializer']
@ -34,8 +34,12 @@ class EndpointSerializer(BulkModelSerializer):
class EndpointRuleSerializer(BulkModelSerializer):
_ip_group_help_text = '{} <br> {}'.format(
ip_group_help_text,
_('If asset IP addresses under different endpoints conflict, use asset labels')
)
ip_group = serializers.ListField(
default=['*'], label=_('IP'), help_text=ip_group_help_text,
default=['*'], label=_('IP'), help_text=_ip_group_help_text,
child=serializers.CharField(max_length=1024, validators=[ip_group_child_validator])
)
endpoint_display = serializers.ReadOnlyField(source='endpoint.name', label=_('Endpoint'))