diff --git a/apps/assets/serializers/asset/database.py b/apps/assets/serializers/asset/database.py index 17a122fd6..633126995 100644 --- a/apps/assets/serializers/asset/database.py +++ b/apps/assets/serializers/asset/database.py @@ -1,8 +1,7 @@ from django.utils.translation import gettext_lazy as _ from rest_framework import serializers -from rest_framework.serializers import ValidationError -from assets.models import Database +from assets.models import Database, Platform from assets.serializers.gateway import GatewayWithAccountSecretSerializer from .common import AssetSerializer @@ -20,13 +19,42 @@ class DatabaseSerializer(AssetSerializer): ] fields = AssetSerializer.Meta.fields + extra_fields - def validate(self, attrs): - platform = attrs.get('platform') - db_type_required = ('mongodb', 'postgresql') - if platform and getattr(platform, 'type') in db_type_required \ - and not attrs.get('db_name'): - raise ValidationError({'db_name': _('This field is required.')}) - return attrs + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.set_db_name_required() + + def get_platform(self): + platform = None + platform_id = None + + if getattr(self, 'initial_data', None): + platform_id = self.initial_data.get('platform') + if isinstance(platform_id, dict): + platform_id = platform_id.get('id') or platform_id.get('pk') + if not platform_id and self.instance: + platform = self.instance.platform + elif getattr(self, 'instance', None): + if isinstance(self.instance, list): + return + platform = self.instance.platform + elif self.context.get('request'): + platform_id = self.context['request'].query_params.get('platform') + + if not platform and platform_id: + platform = Platform.objects.filter(id=platform_id).first() + return platform + + def set_db_name_required(self): + db_field = self.fields.get('db_name') + if not db_field: + return + + platform = self.get_platform() + if not platform: + return + + if platform.type in ['mysql', 'mariadb']: + db_field.required = False class DatabaseWithGatewaySerializer(DatabaseSerializer):