mirror of https://github.com/jumpserver/jumpserver
perf: mysql mariadb 数据库不再必填
parent
614e019f14
commit
23ccd6df8c
|
@ -1,8 +1,7 @@
|
|||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
from rest_framework.serializers import ValidationError
|
||||
|
||||
from assets.models import Database
|
||||
from assets.models import Database, Platform
|
||||
from assets.serializers.gateway import GatewayWithAccountSecretSerializer
|
||||
from .common import AssetSerializer
|
||||
|
||||
|
@ -20,13 +19,42 @@ class DatabaseSerializer(AssetSerializer):
|
|||
]
|
||||
fields = AssetSerializer.Meta.fields + extra_fields
|
||||
|
||||
def validate(self, attrs):
|
||||
platform = attrs.get('platform')
|
||||
db_type_required = ('mongodb', 'postgresql')
|
||||
if platform and getattr(platform, 'type') in db_type_required \
|
||||
and not attrs.get('db_name'):
|
||||
raise ValidationError({'db_name': _('This field is required.')})
|
||||
return attrs
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.set_db_name_required()
|
||||
|
||||
def get_platform(self):
|
||||
platform = None
|
||||
platform_id = None
|
||||
|
||||
if getattr(self, 'initial_data', None):
|
||||
platform_id = self.initial_data.get('platform')
|
||||
if isinstance(platform_id, dict):
|
||||
platform_id = platform_id.get('id') or platform_id.get('pk')
|
||||
if not platform_id and self.instance:
|
||||
platform = self.instance.platform
|
||||
elif getattr(self, 'instance', None):
|
||||
if isinstance(self.instance, list):
|
||||
return
|
||||
platform = self.instance.platform
|
||||
elif self.context.get('request'):
|
||||
platform_id = self.context['request'].query_params.get('platform')
|
||||
|
||||
if not platform and platform_id:
|
||||
platform = Platform.objects.filter(id=platform_id).first()
|
||||
return platform
|
||||
|
||||
def set_db_name_required(self):
|
||||
db_field = self.fields.get('db_name')
|
||||
if not db_field:
|
||||
return
|
||||
|
||||
platform = self.get_platform()
|
||||
if not platform:
|
||||
return
|
||||
|
||||
if platform.type in ['mysql', 'mariadb']:
|
||||
db_field.required = False
|
||||
|
||||
|
||||
class DatabaseWithGatewaySerializer(DatabaseSerializer):
|
||||
|
|
Loading…
Reference in New Issue