feat: 修改 DBPortMapper 异常处理问题; DBListenPort API 迁移至 terminal app 中

pull/8892/head
Jiangjie.Bai 2022-09-22 18:47:16 +08:00
parent 7a6ed91f62
commit c1c70849e9
14 changed files with 249 additions and 224 deletions

View File

@ -1,19 +1,16 @@
# coding: utf-8
#
from orgs.mixins.api import OrgBulkModelViewSet
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from common.tree import TreeNodeSerializer
from common.mixins.api import SuggestionMixin
from ..utils import db_port_manager
from .. import serializers
from ..models import Application
__all__ = ['ApplicationViewSet', 'DBListenPortViewSet']
__all__ = ['ApplicationViewSet']
class ApplicationViewSet(SuggestionMixin, OrgBulkModelViewSet):
@ -41,30 +38,3 @@ class ApplicationViewSet(SuggestionMixin, OrgBulkModelViewSet):
tree_nodes = Application.create_tree_nodes(queryset, show_count=show_count)
serializer = self.get_serializer(tree_nodes, many=True)
return Response(serializer.data)
class DBListenPortViewSet(GenericViewSet):
rbac_perms = {
'GET': 'applications.view_application',
'list': 'applications.view_application',
'db_info': 'applications.view_application',
}
http_method_names = ['get', 'post']
def list(self, request, *args, **kwargs):
ports = db_port_manager.get_already_use_ports()
return Response(data=ports, status=status.HTTP_200_OK)
@action(methods=['post'], detail=False, url_path='db-info')
def db_info(self, request, *args, **kwargs):
port = request.data.get("port")
db, msg = db_port_manager.get_db_by_port(port)
if db:
serializer = serializers.AppSerializer(instance=db)
data = serializer.data
_status = status.HTTP_200_OK
else:
data = {'error': msg}
_status = status.HTTP_404_NOT_FOUND
return Response(data=data, status=_status)

View File

@ -1,36 +1,2 @@
# -*- coding: utf-8 -*-
#
from django.db.models.signals import post_save, post_delete
from common.signals import django_ready
from django.dispatch import receiver
from common.utils import get_logger
from .models import Application
from .utils import db_port_manager, DBPortManager
db_port_manager: DBPortManager
logger = get_logger(__file__)
@receiver(django_ready)
def init_db_port_mapper(sender, **kwargs):
logger.info('Init db port mapper')
db_port_manager.init()
@receiver(post_save, sender=Application)
def on_db_app_created(sender, instance: Application, created, **kwargs):
if not instance.category_db:
return
if not created:
return
db_port_manager.add(instance)
@receiver(post_delete, sender=Application)
def on_db_app_delete(sender, instance, **kwargs):
if not instance.category_db:
return
db_port_manager.pop(instance)
#

View File

@ -13,7 +13,6 @@ router.register(r'applications', api.ApplicationViewSet, 'application')
router.register(r'accounts', api.ApplicationAccountViewSet, 'application-account')
router.register(r'system-users-apps-relations', api.SystemUserAppRelationViewSet, 'system-users-apps-relation')
router.register(r'account-secrets', api.ApplicationAccountSecretViewSet, 'application-account-secret')
router.register(r'db-listen-ports', api.DBListenPortViewSet, 'db-listen-ports')
urlpatterns = [

View File

@ -8,3 +8,4 @@ from .storage import *
from .status import *
from .sharing import *
from .endpoint import *
from .db_listen_port import *

View File

@ -0,0 +1,36 @@
# coding: utf-8
#
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from ..utils import db_port_manager, DBPortManager
from applications import serializers
db_port_manager: DBPortManager
__all__ = ['DBListenPortViewSet']
class DBListenPortViewSet(GenericViewSet):
rbac_perms = {
'GET': 'applications.view_application',
'list': 'applications.view_application',
'db_info': 'applications.view_application',
}
http_method_names = ['get', 'post']
def list(self, request, *args, **kwargs):
ports = db_port_manager.get_already_use_ports()
return Response(data=ports, status=status.HTTP_200_OK)
@action(methods=['get'], detail=False, url_path='db-info')
def db_info(self, request, *args, **kwargs):
port = request.query_params.get("port")
db = db_port_manager.get_db_by_port(port)
serializer = serializers.AppSerializer(instance=db)
return Response(data=serializer.data, status=status.HTTP_200_OK)

View File

@ -2,12 +2,14 @@ from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from applications.models import Application
from applications.utils import db_port_manager
from ..utils import db_port_manager, DBPortManager
from common.db.models import JMSModel
from common.db.fields import PortField
from common.utils.ip import contains_ip
from common.exceptions import JMSException
db_port_manager: DBPortManager
class Endpoint(JMSModel):
name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
@ -34,10 +36,6 @@ class Endpoint(JMSModel):
port = getattr(self, f'{protocol}_port', 0)
elif isinstance(target_instance, Application) and target_instance.category_db:
port = db_port_manager.get_port_by_db(target_instance)
if port is None:
error = 'No application port is matched, application id: {}' \
''.format(target_instance.id)
raise JMSException(error)
else:
port = 0
return port

View File

@ -2,9 +2,10 @@ from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from common.drf.serializers import BulkModelSerializer
from acls.serializers.rules import ip_group_child_validator, ip_group_help_text
from applications.utils import db_port_manager
from ..utils import db_port_manager
from ..models import Endpoint, EndpointRule
__all__ = ['EndpointSerializer', 'EndpointRuleSerializer']

View File

@ -1,2 +1,36 @@
# -*- coding: utf-8 -*-
#
from django.db.models.signals import post_save, post_delete
from common.signals import django_ready
from django.dispatch import receiver
from common.utils import get_logger
from .models import Application
from .utils import db_port_manager, DBPortManager
db_port_manager: DBPortManager
logger = get_logger(__file__)
@receiver(django_ready)
def init_db_port_mapper(sender, **kwargs):
logger.info('Init db port mapper')
db_port_manager.init()
@receiver(post_save, sender=Application)
def on_db_app_created(sender, instance: Application, created, **kwargs):
if not instance.category_db:
return
if not created:
return
db_port_manager.add(instance)
@receiver(post_delete, sender=Application)
def on_db_app_delete(sender, instance, **kwargs):
if not instance.category_db:
return
db_port_manager.pop(instance)

View File

@ -24,6 +24,7 @@ router.register(r'session-sharings', api.SessionSharingViewSet, 'session-sharing
router.register(r'session-join-records', api.SessionJoinRecordsViewSet, 'session-sharing-record')
router.register(r'endpoints', api.EndpointViewSet, 'endpoint')
router.register(r'endpoint-rules', api.EndpointRuleViewSet, 'endpoint-rule')
router.register(r'db-listen-ports', api.DBListenPortViewSet, 'db-listen-ports')
urlpatterns = [
path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'),

View File

@ -0,0 +1,4 @@
from .components import *
from .common import *
from .session_replay import *
from .db_port_mapper import *

View File

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
#
from common.utils import get_logger
from .. import const
from tickets.models import TicketSession
logger = get_logger(__name__)
class ComputeStatUtil:
# system status
@staticmethod
def _common_compute_system_status(value, thresholds):
if thresholds[0] <= value <= thresholds[1]:
return const.ComponentStatusChoices.normal.value
elif thresholds[1] < value <= thresholds[2]:
return const.ComponentStatusChoices.high.value
else:
return const.ComponentStatusChoices.critical.value
@classmethod
def _compute_system_stat_status(cls, stat):
system_stat_thresholds_mapper = {
'cpu_load': [0, 5, 20],
'memory_used': [0, 85, 95],
'disk_used': [0, 80, 99]
}
system_status = {}
for stat_key, thresholds in system_stat_thresholds_mapper.items():
stat_value = getattr(stat, stat_key)
if stat_value is None:
msg = 'stat: {}, stat_key: {}, stat_value: {}'
logger.debug(msg.format(stat, stat_key, stat_value))
stat_value = 0
status = cls._common_compute_system_status(stat_value, thresholds)
system_status[stat_key] = status
return system_status
@classmethod
def compute_component_status(cls, stat):
if not stat:
return const.ComponentStatusChoices.offline
system_status_values = cls._compute_system_stat_status(stat).values()
if const.ComponentStatusChoices.critical in system_status_values:
return const.ComponentStatusChoices.critical
elif const.ComponentStatusChoices.high in system_status_values:
return const.ComponentStatusChoices.high
else:
return const.ComponentStatusChoices.normal
def is_session_approver(session_id, user_id):
ticket = TicketSession.get_ticket_by_session_id(session_id)
if not ticket:
return False
ok = ticket.has_all_assignee(user_id)
return ok

View File

@ -1,124 +1,13 @@
# -*- coding: utf-8 -*-
#
import os
from itertools import groupby, chain
from django.conf import settings
from django.core.files.storage import default_storage
import jms_storage
from itertools import groupby
from common.utils import get_logger
from . import const
from .models import ReplayStorage
from tickets.models import TicketSession, TicketStep, TicketAssignee
from tickets.const import StepState
logger = get_logger(__name__)
def find_session_replay_local(session):
# 存在外部存储上,所有可能的路径名
session_paths = session.get_all_possible_relative_path()
# 存在本地存储上,所有可能的路径名
local_paths = session.get_all_possible_local_path()
for _local_path in chain(session_paths, local_paths):
if default_storage.exists(_local_path):
url = default_storage.url(_local_path)
return _local_path, url
return None, None
def download_session_replay(session):
replay_storages = ReplayStorage.objects.all()
configs = {
storage.name: storage.config
for storage in replay_storages
if not storage.type_null_or_server
}
if settings.SERVER_REPLAY_STORAGE:
configs['SERVER_REPLAY_STORAGE'] = settings.SERVER_REPLAY_STORAGE
if not configs:
msg = "Not found replay file, and not remote storage set"
return None, msg
storage = jms_storage.get_multi_object_storage(configs)
# 获取外部存储路径名
session_path = session.find_ok_relative_path_in_storage(storage)
if not session_path:
msg = "Not found session replay file"
return None, msg
# 通过外部存储路径名后缀,构造真实的本地存储路径
local_path = session.get_local_path_by_relative_path(session_path)
# 保存到storage的路径
target_path = os.path.join(default_storage.base_location, local_path)
target_dir = os.path.dirname(target_path)
if not os.path.isdir(target_dir):
os.makedirs(target_dir, exist_ok=True)
ok, err = storage.download(session_path, target_path)
if not ok:
msg = "Failed download replay file: {}".format(err)
logger.error(msg)
return None, msg
url = default_storage.url(local_path)
return local_path, url
def get_session_replay_url(session):
local_path, url = find_session_replay_local(session)
if local_path is None:
local_path, url = download_session_replay(session)
return local_path, url
class ComputeStatUtil:
# system status
@staticmethod
def _common_compute_system_status(value, thresholds):
if thresholds[0] <= value <= thresholds[1]:
return const.ComponentStatusChoices.normal.value
elif thresholds[1] < value <= thresholds[2]:
return const.ComponentStatusChoices.high.value
else:
return const.ComponentStatusChoices.critical.value
@classmethod
def _compute_system_stat_status(cls, stat):
system_stat_thresholds_mapper = {
'cpu_load': [0, 5, 20],
'memory_used': [0, 85, 95],
'disk_used': [0, 80, 99]
}
system_status = {}
for stat_key, thresholds in system_stat_thresholds_mapper.items():
stat_value = getattr(stat, stat_key)
if stat_value is None:
msg = 'stat: {}, stat_key: {}, stat_value: {}'
logger.debug(msg.format(stat, stat_key, stat_value))
stat_value = 0
status = cls._common_compute_system_status(stat_value, thresholds)
system_status[stat_key] = status
return system_status
@classmethod
def compute_component_status(cls, stat):
if not stat:
return const.ComponentStatusChoices.offline
system_status_values = cls._compute_system_stat_status(stat).values()
if const.ComponentStatusChoices.critical in system_status_values:
return const.ComponentStatusChoices.critical
elif const.ComponentStatusChoices.high in system_status_values:
return const.ComponentStatusChoices.high
else:
return const.ComponentStatusChoices.normal
class TypedComponentsStatusMetricsUtil(object):
def __init__(self):
self.components = []
@ -126,7 +15,7 @@ class TypedComponentsStatusMetricsUtil(object):
self.get_components()
def get_components(self):
from .models import Terminal
from ..models import Terminal
components = Terminal.objects.filter(is_deleted=False).order_by('type')
grouped_components = groupby(components, lambda c: c.type)
grouped_components = [(i[0], list(i[1])) for i in grouped_components]
@ -251,10 +140,3 @@ class ComponentsPrometheusMetricsUtil(TypedComponentsStatusMetricsUtil):
prometheus_metrics_text = '\n'.join(prometheus_metrics)
return prometheus_metrics_text
def is_session_approver(session_id, user_id):
ticket = TicketSession.get_ticket_by_session_id(session_id)
if not ticket:
return False
ok = ticket.has_all_assignee(user_id)
return ok

View File

@ -6,6 +6,7 @@ from applications.models import Application
from common.utils import get_logger
from common.utils import get_object_or_none
from orgs.utils import tmp_to_root_org
from common.exceptions import JMSException
logger = get_logger(__file__)
@ -22,24 +23,23 @@ class DBPortManager(object):
self.port_limit = settings.MAGNUS_DB_PORTS_LIMIT_COUNT
self.port_end = self.port_start + self.port_limit
# 可以使用的端口列表
self.all_usable_ports = [i for i in range(self.port_start, self.port_end+1)]
self.all_available_ports = list(range(self.port_start, self.port_end + 1))
@property
def magnus_listen_port_range(self):
return f'{self.port_start}-{self.port_end}'
def init(self):
db_ids = Application.objects.filter(category=AppCategory.db).values_list('id', flat=True)
with tmp_to_root_org():
db_ids = Application.objects.filter(category=AppCategory.db).values_list('id', flat=True)
db_ids = [str(i) for i in db_ids]
mapper = dict(zip(self.all_usable_ports, list(db_ids)))
mapper = dict(zip(self.all_available_ports, list(db_ids)))
self.set_mapper(mapper)
def add(self, db: Application):
mapper = self.get_mapper()
usable_port = self.get_next_usable_port()
if not usable_port:
return False
mapper.update({usable_port: str(db.id)})
available_port = self.get_next_available_port()
mapper.update({available_port: str(db.id)})
self.set_mapper(mapper)
return True
@ -54,43 +54,42 @@ class DBPortManager(object):
for port, db_id in mapper.items():
if db_id == str(db.id):
return port
logger.warning(
'Not matched db port, db_id: {}, mapper length: {}'.format(db.id, len(mapper))
raise JMSException(
'Not matched db port, db id: {}, mapper length: {}'.format(db.id, len(mapper))
)
def get_db_by_port(self, port):
mapper = self.get_mapper()
db_id = mapper.get(port, None)
if db_id:
with tmp_to_root_org():
db = get_object_or_none(Application, id=db_id)
if not db:
msg = 'Database not exists, database id: {}'.format(db_id)
else:
msg = ''
else:
db = None
msg = 'Port not in port-db mapper, port: {}'.format(port)
return db, msg
if not db_id:
raise JMSException('Database not in port-db mapper, port: {}'.format(port))
with tmp_to_root_org():
db = get_object_or_none(Application, id=db_id)
if not db:
raise JMSException('Database not exists, db id: {}'.format(db_id))
return db
def get_next_usable_port(self):
def get_next_available_port(self):
already_use_ports = self.get_already_use_ports()
usable_ports = sorted(list(set(self.all_usable_ports) - set(already_use_ports)))
if len(usable_ports) > 1:
port = usable_ports[0]
logger.debug('Get next usable port: {}'.format(port))
return port
msg = 'No port is usable, All usable port count: {}, Already use port count: {}'.format(
len(self.all_usable_ports), len(already_use_ports)
)
logger.warning(msg)
available_ports = sorted(list(set(self.all_available_ports) - set(already_use_ports)))
if len(available_ports) <= 0:
raise JMSException(
'No port is available, All available port count: {}, Already use port count: {}'
''.format(len(self.all_available_ports), len(already_use_ports))
)
port = available_ports[0]
logger.debug('Get next available port: {}'.format(port))
return port
def get_already_use_ports(self):
mapper = self.get_mapper()
return sorted(list(mapper.keys()))
def get_mapper(self):
mapper = cache.get(self.CACHE_KEY, {})
if not mapper:
# redis 可能被清空,重新初始化一下
self.init()
return cache.get(self.CACHE_KEY, {})
def set_mapper(self, value):

View File

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
#
import os
from itertools import groupby, chain
from django.conf import settings
from django.core.files.storage import default_storage
import jms_storage
from common.utils import get_logger
from ..models import ReplayStorage
logger = get_logger(__name__)
def find_session_replay_local(session):
# 存在外部存储上,所有可能的路径名
session_paths = session.get_all_possible_relative_path()
# 存在本地存储上,所有可能的路径名
local_paths = session.get_all_possible_local_path()
for _local_path in chain(session_paths, local_paths):
if default_storage.exists(_local_path):
url = default_storage.url(_local_path)
return _local_path, url
return None, None
def download_session_replay(session):
replay_storages = ReplayStorage.objects.all()
configs = {
storage.name: storage.config
for storage in replay_storages
if not storage.type_null_or_server
}
if settings.SERVER_REPLAY_STORAGE:
configs['SERVER_REPLAY_STORAGE'] = settings.SERVER_REPLAY_STORAGE
if not configs:
msg = "Not found replay file, and not remote storage set"
return None, msg
storage = jms_storage.get_multi_object_storage(configs)
# 获取外部存储路径名
session_path = session.find_ok_relative_path_in_storage(storage)
if not session_path:
msg = "Not found session replay file"
return None, msg
# 通过外部存储路径名后缀,构造真实的本地存储路径
local_path = session.get_local_path_by_relative_path(session_path)
# 保存到storage的路径
target_path = os.path.join(default_storage.base_location, local_path)
target_dir = os.path.dirname(target_path)
if not os.path.isdir(target_dir):
os.makedirs(target_dir, exist_ok=True)
ok, err = storage.download(session_path, target_path)
if not ok:
msg = "Failed download replay file: {}".format(err)
logger.error(msg)
return None, msg
url = default_storage.url(local_path)
return local_path, url
def get_session_replay_url(session):
local_path, url = find_session_replay_local(session)
if local_path is None:
local_path, url = download_session_replay(session)
return local_path, url