# -*- coding: utf-8 -*-
#

from rest_framework.request import Request

__all__ = ['SerializerMixin']


class SerializerMixin:
    """ 根据用户请求动作的不同,获取不同的 `serializer_class `"""

    action: str
    request: Request

    serializer_classes = None
    single_actions = ['put', 'retrieve', 'patch']

    def get_serializer_classes(self):
        classes = getattr(self, 'serializer_classes', None) or {}
        return dict(classes)

    def get_serializer_class_by_view_action(self):
        serializer_classes = self.get_serializer_classes()
        if serializer_classes is None:
            return None
        if not isinstance(serializer_classes, dict):
            return None
        serializer_classes = dict(serializer_classes)
        view_action = self.request.query_params.get('action') or self.action or 'list'
        if self.request.query_params.get('format'):
            view_action = 'retrieve'
        serializer_class = serializer_classes.get(view_action)

        if serializer_class is None:
            view_method = self.request.method.lower()
            serializer_class = serializer_classes.get(view_method)

        if serializer_class is None and view_action in self.single_actions:
            serializer_class = serializer_classes.get('single')
        if serializer_class is None:
            serializer_class = serializer_classes.get('display')
        if serializer_class is None:
            serializer_class = serializer_classes.get('default')
        return serializer_class

    def get_serializer_class(self):
        serializer_class = self.get_serializer_class_by_view_action()
        if serializer_class is None:
            serializer_class = super().get_serializer_class()
        return serializer_class