# -*- coding: utf-8 -*- """ @author: 猿小天 @contact: QQ:1638245306 @Created on: 2021/6/1 001 22:47 @Remark: 自定义序列化器 """ from rest_framework import serializers from rest_framework.fields import empty from rest_framework.request import Request from rest_framework.serializers import ModelSerializer from django.utils.functional import cached_property from rest_framework.utils.serializer_helpers import BindingDict from dvadmin.system.models import Users from django_restql.mixins import DynamicFieldsMixin class CustomModelSerializer(DynamicFieldsMixin, ModelSerializer): """ 增强DRF的ModelSerializer,可自动更新模型的审计字段记录 (1)self.request能获取到rest_framework.request.Request对象 """ # 修改人的审计字段名称, 默认modifier, 继承使用时可自定义覆盖 modifier_field_id = "modifier" modifier_name = serializers.SerializerMethodField(read_only=True) def get_modifier_name(self, instance): if not hasattr(instance, "modifier"): return None queryset = ( Users.objects.filter(id=instance.modifier) .values_list("name", flat=True) .first() ) if queryset: return queryset return None # 创建人的审计字段名称, 默认creator, 继承使用时可自定义覆盖 creator_field_id = "creator" creator_name = serializers.SlugRelatedField( slug_field="name", source="creator", read_only=True ) # 数据所属部门字段 dept_belong_id_field_name = "dept_belong_id" # 添加默认时间返回格式 create_datetime = serializers.DateTimeField( format="%Y-%m-%d %H:%M:%S", required=False, read_only=True ) update_datetime = serializers.DateTimeField( format="%Y-%m-%d %H:%M:%S", required=False ) def __init__(self, instance=None, data=empty, request=None, **kwargs): super().__init__(instance, data, **kwargs) self.request: Request = request or self.context.get("request", None) def save(self, **kwargs): return super().save(**kwargs) def create(self, validated_data): if self.request: if str(self.request.user) != "AnonymousUser": if self.modifier_field_id in self.fields.fields: validated_data[self.modifier_field_id] = self.get_request_user_id() if self.creator_field_id in self.fields.fields: validated_data[self.creator_field_id] = self.request.user if ( self.dept_belong_id_field_name in self.fields.fields and validated_data.get(self.dept_belong_id_field_name, None) is None ): validated_data[self.dept_belong_id_field_name] = getattr( self.request.user, "dept_id", None ) return super().create(validated_data) def update(self, instance, validated_data): if self.request: if hasattr(self.instance, self.modifier_field_id): setattr( self.instance, self.modifier_field_id, self.get_request_user_id() ) return super().update(instance, validated_data) def get_request_username(self): if getattr(self.request, "user", None): return getattr(self.request.user, "username", None) return None def get_request_name(self): if getattr(self.request, "user", None): return getattr(self.request.user, "name", None) return None def get_request_user_id(self): if getattr(self.request, "user", None): return getattr(self.request.user, "id", None) return None # @cached_property # def fields(self): # fields = BindingDict(self) # for key, value in self.get_fields().items(): # fields[key] = value # # if not hasattr(self, '_context'): # return fields # is_root = self.root == self # parent_is_list_root = self.parent == self.root and getattr(self.parent, 'many', False) # if not (is_root or parent_is_list_root): # return fields # # try: # request = self.request or self.context['request'] # except KeyError: # return fields # params = getattr( # request, 'query_params', getattr(request, 'GET', None) # ) # if params is None: # pass # try: # filter_fields = params.get('_fields', None).split(',') # except AttributeError: # filter_fields = None # # try: # omit_fields = params.get('_exclude', None).split(',') # except AttributeError: # omit_fields = [] # # existing = set(fields.keys()) # if filter_fields is None: # allowed = existing # else: # allowed = set(filter(None, filter_fields)) # # omitted = set(filter(None, omit_fields)) # for field in existing: # if field not in allowed: # fields.pop(field, None) # if field in omitted: # fields.pop(field, None) # # return fields