mirror of https://github.com/jumpserver/jumpserver
65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
from django.utils.translation import gettext_lazy as _
|
|
from rest_framework import serializers
|
|
|
|
from assets.models import Database, Platform
|
|
from assets.serializers.gateway import GatewayWithAccountSecretSerializer
|
|
from .common import AssetSerializer
|
|
|
|
# Names exported by ``from <this module> import *``: the two serializer
# classes defined in this file.
__all__ = ['DatabaseSerializer', 'DatabaseWithGatewaySerializer']
|
|
|
|
|
|
class DatabaseSerializer(AssetSerializer):
    """Serializer for ``Database`` assets.

    Extends ``AssetSerializer`` with the database-specific fields listed in
    ``Meta.extra_fields``. ``db_name`` is declared as required, but the
    requirement is dropped at construction time when the resolved platform's
    type is ``mysql`` or ``mariadb`` (see ``set_db_name_required``).
    """

    db_name = serializers.CharField(max_length=1024, label=_('Default database'), required=True)

    class Meta(AssetSerializer.Meta):
        model = Database
        # Database-only fields appended to the inherited asset field list.
        extra_fields = [
            'db_name', 'use_ssl', 'ca_cert', 'client_cert',
            'client_key', 'allow_invalid_cert'
        ]
        fields = AssetSerializer.Meta.fields + extra_fields

    def __init__(self, *args, **kwargs):
        """Build the serializer, then adjust ``db_name.required`` by platform."""
        super().__init__(*args, **kwargs)
        self.set_db_name_required()

    def get_platform(self):
        """Resolve the platform this serializer is operating on.

        Lookup order:

        1. ``initial_data['platform']`` (a raw id, or a dict carrying ``id``
           or ``pk``); if that is empty and an instance is bound, the
           instance's platform is used instead.
        2. Otherwise the bound instance's ``platform``.
        3. Otherwise the ``platform`` query parameter of the request found
           in the serializer context.

        Returns:
            The matching ``Platform`` object, or ``None`` when no platform
            can be determined (including bulk usage, where the bound data or
            instance is a list rather than a single object).
        """
        platform = None
        platform_id = None

        initial_data = getattr(self, 'initial_data', None)
        if initial_data:
            if isinstance(initial_data, list):
                # With ``many=True`` DRF constructs the child serializer with
                # the same ``data`` argument, so a bulk payload arrives here
                # as a list. There is no single platform to resolve, and
                # calling ``.get`` on the list would raise AttributeError.
                # This mirrors the list guard on ``self.instance`` below.
                return None
            platform_id = initial_data.get('platform')
            if isinstance(platform_id, dict):
                platform_id = platform_id.get('id') or platform_id.get('pk')
            if not platform_id and self.instance:
                # Payload omits the platform (e.g. partial update): fall
                # back to the platform of the object being updated.
                platform = self.instance.platform
        elif getattr(self, 'instance', None):
            if isinstance(self.instance, list):
                # Serializing many objects at once: no single platform.
                return None
            platform = self.instance.platform
        elif self.context.get('request'):
            platform_id = self.context['request'].query_params.get('platform')

        if not platform and platform_id:
            platform = Platform.objects.filter(id=platform_id).first()
        return platform

    def set_db_name_required(self):
        """Make ``db_name`` optional for MySQL and MariaDB platforms.

        Does nothing when the ``db_name`` field is not present on this
        serializer or when no platform can be resolved; in those cases the
        field keeps its declared ``required`` setting.
        """
        db_field = self.fields.get('db_name')
        if not db_field:
            return

        platform = self.get_platform()
        if not platform:
            return

        if platform.type in ['mysql', 'mariadb']:
            db_field.required = False
|
|
|
|
|
|
class DatabaseWithGatewaySerializer(DatabaseSerializer):
    """``DatabaseSerializer`` plus a nested ``gateway`` representation.

    The ``gateway`` field is rendered by
    ``GatewayWithAccountSecretSerializer`` and is appended to the field list
    inherited from ``DatabaseSerializer.Meta``.
    """

    gateway = GatewayWithAccountSecretSerializer()

    class Meta(DatabaseSerializer.Meta):
        # Parent's fields followed by the extra nested field.
        fields = [*DatabaseSerializer.Meta.fields, 'gateway']
|