# -*- coding: utf-8 -*-
#
import phonenumbers
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import Model
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from rest_framework.fields import empty

from common.db.fields import TreeChoices, JSONManyToManyField as ModelJSONManyToManyField
from common.utils import decrypt_password, is_uuid

__all__ = [
    "ReadableHiddenField",
    "EncryptedField",
    "LabeledChoiceField",
    "ObjectRelatedField",
    "BitChoicesField",
    "TreeChoicesField",
    "LabeledMultipleChoiceField",
    "PhoneField",
    "JSONManyToManyField",
    "LabelRelatedField",
]


def _unwrap_choice_values(data):
    """Normalize multiple-choice input to a list of raw choice values.

    Clients may send either plain values (``["a", "b"]``) or the labeled
    objects produced by ``to_representation`` (``[{"value": "a", ...}]``).
    When the first item is a dict, every item's ``"value"`` is extracted;
    otherwise ``data`` is returned unchanged. Empty input is returned as-is.
    """
    if not data:
        return data
    if isinstance(data[0], dict):
        return [item.get("value") for item in data]
    return data


# ReadableHiddenField
# -------------------


class ReadableHiddenField(serializers.HiddenField):
    """A HiddenField whose value is also rendered in serialized output."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # HiddenField is write-only; re-enable output so the value is readable.
        self.write_only = False

    def to_representation(self, value):
        # Objects are rendered by their ``id``; anything else is passed through.
        if hasattr(value, "id"):
            return getattr(value, "id")
        return value


class EncryptedField(serializers.CharField):
    """CharField whose incoming value is decrypted with ``decrypt_password``.

    The field is write-only unless ``write_only`` is explicitly given.
    """

    def __init__(self, write_only=None, **kwargs):
        if write_only is None:
            write_only = True
        kwargs["write_only"] = write_only
        super().__init__(**kwargs)

    def to_internal_value(self, value):
        value = super().to_internal_value(value)
        return decrypt_password(value)


class LabeledChoiceField(serializers.ChoiceField):
    """ChoiceField rendered as ``{"value": key, "label": label}``.

    Input may be the raw key, a ``{"value": key}`` dict, or a string ending in
    ``"(key)"`` such as ``"Some label(key)"``; the key inside the final
    parentheses is used.
    """

    def to_representation(self, key):
        if key is None:
            return key
        label = self.choices.get(key, key)
        return {"value": key, "label": label}

    def to_internal_value(self, data):
        if isinstance(data, dict):
            data = data.get("value")

        # "Label(key)" -> "key": take the text after the last "(".
        if isinstance(data, str) and "(" in data and data.endswith(")"):
            data = data.strip(")").split('(')[-1]
        return super().to_internal_value(data)


class LabeledMultipleChoiceField(serializers.MultipleChoiceField):
    """MultipleChoiceField rendered as a list of ``{"value", "label"}`` dicts."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Snapshot of the choices at construction time: key -> label.
        self.choice_mapper = {
            key: value for key, value in self.choices.items()
        }

    def to_representation(self, keys):
        if keys is None:
            return keys
        return [
            {"value": key, "label": self.choice_mapper.get(key)}
            for key in keys
        ]

    def to_internal_value(self, data):
        # NOTE: values are not checked against ``choices`` here (unchanged behavior).
        return _unwrap_choice_values(data)


class LabelRelatedField(serializers.RelatedField):
    """Related field that converts label input into ``LabeledResource`` objects.

    Output is ``{"id", "name", "value", "color"}`` of the resource's label.
    """

    def __init__(self, **kwargs):
        queryset = kwargs.pop("queryset", None)
        if queryset is None:
            # Imported lazily, presumably to avoid an import cycle with the
            # labels app -- confirm.
            from labels.models import LabeledResource
            queryset = LabeledResource.objects.all()

        # Read-only related fields must not be given a queryset.
        read_only = kwargs.get("read_only", False)
        if not read_only:
            kwargs["queryset"] = queryset
        super().__init__(**kwargs)

    def to_file_representation(self, value):
        """Render a label dict as ``"name:value"`` for file exports."""
        if value is None:
            return value
        return "{}:{}".format(value.get('name'), value.get('value'))

    def to_representation(self, value):
        if value is None:
            return value
        label = value.label
        if not label:
            return None
        return {'id': label.id, 'name': label.name, 'value': label.value, 'color': label.color}

    def to_internal_value(self, data):
        """Resolve ``data`` into an unsaved ``LabeledResource`` wrapping a ``Label``.

        Accepted input:
          * ``{"id": ...}`` / ``{"pk": ...}`` or a UUID -- an existing label;
          * ``{"name": ..., "value": ...}`` or ``"name:value"`` -- the label is
            fetched or created with ``get_or_create``.

        Raises:
            serializers.ValidationError: unsupported input type, or a label
                referenced by primary key does not exist.
        """
        from labels.models import LabeledResource, Label

        if data is None:
            return data

        if isinstance(data, dict) and (data.get("id") or data.get("pk")):
            label = self._get_label_by_pk(Label, data.get("id") or data.get("pk"))
        elif is_uuid(data):
            label = self._get_label_by_pk(Label, data)
        else:
            if isinstance(data, dict):
                k = data.get("name")
                v = data.get("value")
            elif isinstance(data, str) and ":" in data:
                k, v = [x.strip() for x in data.split(":", 1)]
            else:
                raise serializers.ValidationError(_("Invalid data type"))
            label, __ = Label.objects.get_or_create(name=k, value=v, defaults={'name': k, 'value': v})
        return LabeledResource(label=label)

    @staticmethod
    def _get_label_by_pk(label_model, pk):
        """Fetch a label by pk, reporting a miss as a field validation error."""
        try:
            return label_model.objects.get(pk=pk)
        except label_model.DoesNotExist:
            raise serializers.ValidationError(
                _('Invalid pk "{pk_value}" - object does not exist.').format(pk_value=pk)
            )


class ObjectRelatedField(serializers.RelatedField):
    """Related field rendered as a dict of selected attributes.

    ``attrs`` (default ``("id", "name")``) lists the attributes included in the
    output. Input may be a model instance, a pk, or a dict containing ``id``,
    ``pk`` or the first attribute in ``attrs``.
    """

    default_error_messages = {
        "required": _("This field is required."),
        "does_not_exist": _('Invalid pk "{pk_value}" - object does not exist.'),
        "incorrect_type": _("Incorrect type. Expected pk value, received {data_type}."),
    }

    def __init__(self, **kwargs):
        self.attrs = kwargs.pop("attrs", None) or ("id", "name")
        # NOTE(review): DRF's RelatedField.__new__ pops ``many`` and builds a
        # ManyRelatedField around a child created without it, so this may
        # always be False here -- confirm.
        self.many = kwargs.get("many", False)
        super().__init__(**kwargs)

    def to_representation(self, value):
        data = {}
        for attr in self.attrs:
            if not hasattr(value, attr):
                continue
            data[attr] = getattr(value, attr)
        return data

    def to_internal_value(self, data):
        queryset = self.get_queryset()
        if isinstance(data, Model):
            # Re-fetch through the queryset so instances outside it are rejected
            # with a field error instead of an uncaught DoesNotExist.
            pk = data.pk
        elif not isinstance(data, dict):
            pk = data
        else:
            pk = data.get("id") or data.get("pk") or data.get(self.attrs[0])

        try:
            # bool is an int subclass; never treat True/False as a pk.
            if isinstance(data, bool):
                raise TypeError
            return queryset.get(pk=pk)
        except ObjectDoesNotExist:
            self.fail("does_not_exist", pk_value=pk)
        except (TypeError, ValueError):
            self.fail("incorrect_type", data_type=type(pk).__name__)

    def get_schema(self):
        """Return an OpenAPI schema for this field (for drf-spectacular)."""
        if self.many:
            # Many relation: array of objects.
            return {
                'type': 'array',
                'items': self._get_openapi_item_schema(),
                'description': getattr(self, 'help_text', ''),
                'title': getattr(self, 'label', ''),
            }
        # Single relation: one object.
        return {
            'type': 'object',
            'properties': self._get_openapi_properties_schema(),
            'description': getattr(self, 'help_text', ''),
            'title': getattr(self, 'label', ''),
        }

    def _get_openapi_item_schema(self):
        """Return the OpenAPI schema of one array item."""
        return self._get_openapi_object_schema()

    def _get_openapi_object_schema(self):
        """Return the OpenAPI object schema built from ``self.attrs``."""
        properties = {}
        for attr in self.attrs:
            properties[attr] = {
                'type': self._infer_field_type(attr),
                'description': f'{attr} field'
            }

        return {
            'type': 'object',
            'properties': properties,
            'required': ['id'] if 'id' in self.attrs else []
        }

    def _infer_field_type(self, attr_name):
        """Infer the OpenAPI type of ``attr_name``.

        Uses the queryset model's field definition when available, otherwise
        falls back to name-based heuristics.
        """
        try:
            if hasattr(self, 'queryset') and self.queryset is not None:
                model = self.queryset.model
                if hasattr(model, '_meta') and hasattr(model._meta, 'fields'):
                    field = model._meta.get_field(attr_name)
                    if field:
                        return self._map_django_field_type(field)
        except Exception:
            # Best effort: unknown attributes (e.g. properties) fall through
            # to the heuristic below.
            pass

        return self._heuristic_field_type(attr_name)

    def _map_django_field_type(self, field):
        """Map a Django model field to an OpenAPI primitive type name."""
        field_type = type(field).__name__
        # Auto primary keys (AutoField/BigAutoField/SmallAutoField) are
        # integers even though their class names lack "Integer".
        if 'Integer' in field_type or 'AutoField' in field_type:
            return 'integer'
        if 'Float' in field_type or 'Decimal' in field_type:
            return 'number'
        if 'Boolean' in field_type:
            return 'boolean'
        # Date/time, file/image and all other fields are exposed as strings.
        return 'string'

    def _heuristic_field_type(self, attr_name):
        """Guess an OpenAPI type from the attribute name alone."""
        if attr_name in ['is_active', 'enabled', 'visible'] or attr_name.startswith('is_'):
            return 'boolean'
        elif attr_name in ['count', 'number', 'size', 'amount']:
            return 'integer'
        elif attr_name in ['price', 'rate', 'percentage']:
            return 'number'
        return 'string'

    def _get_openapi_properties_schema(self):
        """Return only the ``properties`` mapping of the object schema."""
        return self._get_openapi_object_schema()['properties']


class TreeChoicesField(serializers.MultipleChoiceField):
    """MultipleChoiceField whose choices come from a ``TreeChoices`` class."""

    def __init__(self, choice_cls, **kwargs):
        assert issubclass(choice_cls, TreeChoices)
        choices = [(c.name, c.label) for c in choice_cls]
        self.tree = choice_cls.tree()
        self._choice_cls = choice_cls
        super().__init__(choices=choices, **kwargs)

    def to_internal_value(self, data):
        return _unwrap_choice_values(data)


class BitChoicesField(TreeChoicesField):
    """Bit-flag field: a set of choice names stored as an OR-ed integer."""

    def to_representation(self, value):
        if isinstance(value, list):
            # Swagger passes choice names in as a list (from
            # field.choices.keys()); render the matching choices. Lists of any
            # length are accepted -- previously only length 1 worked and other
            # lengths raised TypeError on the bitwise test below.
            return [
                {"value": c.name, "label": c.label}
                for c in self._choice_cls
                if c.name in value
            ]
        return [
            {"value": c.name, "label": c.label}
            for c in self._choice_cls
            if c.value & value == c.value
        ]

    def to_internal_value(self, data):
        """Convert a list of choice names (or labeled dicts) to an int bitmask.

        ``"all"`` in the list selects every choice. Raises ``ValidationError``
        if ``data`` is not a list or contains an unknown name.
        """
        if not isinstance(data, list):
            raise serializers.ValidationError(_("Invalid data type, should be list"))

        value = 0
        if not data:
            return value

        if isinstance(data[0], dict):
            data = [d["value"] for d in data]

        # "all" selects every choice.
        if "all" in data:
            for c in self._choice_cls:
                value |= c.value
            return value

        name_value_map = {c.name: c.value for c in self._choice_cls}
        for name in data:
            if name not in name_value_map:
                raise serializers.ValidationError(_("Invalid choice: {}").format(name))
            value |= name_value_map[name]
        return value

    def get_schema(self):
        """Return an OpenAPI schema for this field (for drf-spectacular)."""
        return {
            'type': 'array',
            'items': {
                'type': 'object',
                'properties': {
                    'value': {'type': 'string'},
                    'label': {'type': 'string'}
                }
            },
            'description': getattr(self, 'help_text', ''),
            'title': getattr(self, 'label', ''),
        }

    def run_validation(self, data=empty):
        """Validate and always return the integer bitmask.

        When an asset-permission rule is created without ``actions``, the
        default (set on AssetPermission) such as ``['connect', '...']`` would
        otherwise be saved to the database as-is, causing a type error. So a
        non-int result is passed through ``to_internal_value`` again to convert
        it to the internal integer value.
        """
        data = super().run_validation(data)
        if isinstance(data, int):
            return data
        value = self.to_internal_value(data)
        self.run_validators(value)
        return value


class PhoneField(serializers.CharField):
    """Phone number field, stored as ``"+<country code><national number>"``.

    Input may be a string or ``{"code": ..., "phone": ...}``; numbers are parsed
    with China (``CN``) as the default region. Output is
    ``{"code": "+<cc>", "phone": <national number>}``.
    """

    def to_internal_value(self, data):
        if isinstance(data, dict):
            code = data.get('code')
            phone = data.get('phone', '')
            if code and phone:
                code = code.replace('+', '')
                data = '+{}{}'.format(code, phone)
            else:
                data = phone
        if data:
            try:
                phone = phonenumbers.parse(data, 'CN')
                data = '+{}{}'.format(phone.country_code, phone.national_number)
            except phonenumbers.NumberParseException:
                # Unparseable: assume a mainland-China number.
                data = '+86{}'.format(data)
        return super().to_internal_value(data)

    def to_representation(self, value):
        try:
            phone = phonenumbers.parse(value, 'CN')
            value = {'code': '+%s' % phone.country_code, 'phone': phone.national_number}
        except phonenumbers.NumberParseException:
            value = {'code': '+86', 'phone': value}
        return value


class JSONManyToManyField(serializers.JSONField):
    """Serializer counterpart of ``common.db.fields.JSONManyToManyField``."""

    def to_representation(self, manager):
        if manager is None:
            return manager
        value = manager.value
        if not isinstance(value, dict):
            return {"type": "ids", "ids": []}
        if value.get("type") == "ids":
            # Re-read ids through the manager (presumably dropping ids whose
            # objects no longer exist -- confirm) and render them as strings.
            valid_ids = manager.all().values_list("id", flat=True)
            return {"type": "ids", "ids": [str(i) for i in valid_ids]}
        return value

    def to_internal_value(self, data):
        """Parse ``data`` as JSON and validate it with the model field's rules.

        Empty input becomes ``{}``. A ``ValueError`` from ``check_value`` is
        reported as a ``ValidationError``.
        """
        if not data:
            data = {}
        try:
            data = super().to_internal_value(data)
            ModelJSONManyToManyField.check_value(data)
        except ValueError as e:
            raise serializers.ValidationError(e) from e
        # ``data`` is already parsed; parsing it a second time would fail with
        # ``binary=True`` because it is no longer a JSON string.
        return data