288 lines
11 KiB
Python
288 lines
11 KiB
Python
![]() |
import logging
|
||
|
import traceback
|
||
|
from types import FunctionType, MethodType
|
||
|
|
||
|
from rest_framework.exceptions import APIException as DRFAPIException
|
||
|
from rest_framework.request import Request
|
||
|
from rest_framework.views import APIView
|
||
|
|
||
![]() |
from ..utils import exceptions
|
||
|
from ..utils.model_util import ModelRelateUtils
|
||
![]() |
from .logging.view_logger import CustomerRelationshipViewLogger
|
||
|
from .response import SuccessResponse, ErrorResponse
|
||
|
from .serializers import CustomModelSerializer
|
||
|
|
||
|
# Module-level logger; op_exception_handler uses it to record unexpected exceptions.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def op_exception_handler(ex, context):
    """
    Unified exception interception handler (intended as DRF's EXCEPTION_HANDLER).

    Goals: (1) never let a raw 500 response escape -- every error is answered
    with the standard error envelope; (2) show an accurate error message.

    :param ex: the exception raised while handling the request.
    :param context: DRF handler context (view, args, kwargs, request); not used here.
    :return: an ``ErrorResponse`` whose ``msg`` is taken from the exception.
    """
    # Every exception is converted into an ordinary response below, so the
    # ATOMIC_REQUESTS transaction (if enabled) would otherwise be committed,
    # keeping partial writes made before the failure. Mark it for rollback,
    # as DRF's default exception_handler does for the exceptions it handles.
    from rest_framework.views import set_rollback
    set_rollback()

    msg = ''
    if isinstance(ex, DRFAPIException):
        msg = ex.detail
    elif isinstance(ex, exceptions.APIException):
        # Project-level exception type; carries its text in ``message``.
        msg = ex.message
    elif isinstance(ex, Exception):
        # Unexpected error: keep the full traceback in the logs. Passing the
        # exception explicitly does not depend on being inside an ``except``.
        logger.error('Unhandled exception in view', exc_info=ex)
        msg = str(ex)
    return ErrorResponse(msg=msg)
|
||
|
|
||
|
|
||
|
class CustomAPIView(APIView):
    """
    DRF ``APIView`` extended with extra permission hooks and pluggable view loggers.

    ``extra_permission_classes`` are checked for every request; the
    ``<METHOD>_permission_classes`` tuples are checked only when the request
    uses that HTTP method. Both run in ``initial`` after DRF's own checks.
    """
    extra_permission_classes = ()
    # Checked only for GET requests.
    GET_permission_classes = ()

    # Checked only for POST requests.
    POST_permission_classes = ()

    # Checked only for DELETE requests.
    DELETE_permission_classes = ()

    # Checked only for PUT requests.
    PUT_permission_classes = ()

    # Logger classes instantiated per call by ``get_view_loggers``.
    view_logger_classes = ()

    def initial(self, request: Request, *args, **kwargs):
        """Run DRF's initial checks, then the extra and per-method permissions."""
        super().initial(request, *args, **kwargs)
        self.check_extra_permissions(request)
        self.check_method_extra_permissions(request)

    def get_view_loggers(self, request: Request, *args, **kwargs):
        """Return one instance of every configured view logger class (empty list if none)."""
        if not self.view_logger_classes:
            return []
        instances = []
        for cls in self.view_logger_classes:
            instances.append(cls(view=self, request=request, *args, **kwargs))
        return instances

    def handle_logging(self, request: Request, *args, **kwargs):
        """Call ``handle`` and then the optional ``handle_<method>`` on each view logger."""
        loggers = self.get_view_loggers(request, *args, **kwargs)
        hook_name = 'handle_' + request.method.lower()
        for view_logger in loggers:
            view_logger.handle(request, *args, **kwargs)
            hook = getattr(view_logger, hook_name, None)
            # Only plain functions / bound methods count as hooks.
            if not hook or not isinstance(hook, (FunctionType, MethodType)):
                continue
            hook(request, *args, **kwargs)

    def _enforce_permissions(self, request, permissions):
        """Call ``permission_denied`` for every permission that rejects the request."""
        for perm in permissions:
            if perm.has_permission(request, self):
                continue
            self.permission_denied(request, message=getattr(perm, 'message', None))

    def get_extra_permissions(self):
        """Instantiate ``extra_permission_classes``."""
        return [cls() for cls in self.extra_permission_classes]

    def check_extra_permissions(self, request: Request):
        """Enforce the permissions from ``get_extra_permissions``."""
        self._enforce_permissions(request, self.get_extra_permissions())

    def get_method_extra_permissions(self):
        """Instantiate ``<METHOD>_permission_classes`` for the current request method."""
        attr_name = "%s_permission_classes" % self.request.method.upper()
        classes = getattr(self, attr_name, None)
        return [cls() for cls in classes] if classes else []

    def check_method_extra_permissions(self, request):
        """Enforce the permissions from ``get_method_extra_permissions``."""
        self._enforce_permissions(request, self.get_method_extra_permissions())
|
||
|
|
||
|
|
||
|
class BatchModelApIView(CustomAPIView):
    """
    Generic view for batch CRUD on one model.

    Instances are addressed by ``field_name``. ``POST_serializer_class`` /
    ``PUT_serializer_class`` override ``serializer_class`` for that HTTP method.
    """
    model = None
    serializer_class = None
    POST_serializer_class = None
    PUT_serializer_class = None
    # Model field used to look instances up.
    field_name = 'instanceId'
    # PUT payload key holding the list of ids to update.
    instanceId_list_param_name = 'instanceIdList'
    # PUT payload key holding the field values applied to every listed id.
    instance_info_param_name = 'info'

    def get_serializer(self, *args, **kwargs):
        """
        Build a serializer for the current request method.

        Uses ``<METHOD>_serializer_class`` when set, else ``serializer_class``.
        Returns ``None`` when the view has no request.
        """
        if not self.request:
            return None
        serializer_class = (getattr(self, f"{self.request.method}_serializer_class", None)
                            or getattr(self, 'serializer_class'))
        serializer = serializer_class(*args, **kwargs)
        if isinstance(serializer, CustomModelSerializer):
            serializer.request = self.request
        return serializer

    def get(self, request: Request = None, *args, **kwargs):
        """Return the serialized instances whose ``field_name`` is in ``request.data``."""
        queryset = self.model.objects.filter(**{f'{self.field_name}__in': request.data})
        data = self.get_serializer(queryset, many=True).data
        return SuccessResponse(data=data)

    def post(self, request: Request = None, *args, **kwargs):
        """
        Create one instance per item of ``request.data`` and return their serialized data.

        All items are validated before any is saved, so an invalid item no
        longer leaves the items before it already created.
        """
        serializers = [self.get_serializer(data=info) for info in request.data]
        for serializer in serializers:
            serializer.is_valid(raise_exception=True)
        for serializer in serializers:
            serializer.save()
        return SuccessResponse(data=[serializer.data for serializer in serializers])

    def put(self, request: Request = None, *args, **kwargs):
        """
        Apply the same partial update (``info``) to every id in ``instanceIdList``.

        Every instance is fetched and validated before any is saved, so an
        unknown id (``DoesNotExist``) or invalid data no longer half-applies
        the batch. Returns the list of ids.
        """
        instanceId_list = request.data.get(self.instanceId_list_param_name, [])
        info = request.data.get(self.instance_info_param_name, {})
        serializers = []
        for instanceId in instanceId_list:
            instance = self.model.objects.get(**{f'{self.field_name}': instanceId})
            serializer = self.get_serializer(instance=instance, data=info, partial=True)
            serializer.is_valid(raise_exception=True)
            serializers.append(serializer)
        for serializer in serializers:
            serializer.save()
        return SuccessResponse(data=instanceId_list)

    def delete(self, request: Request = None, *args, **kwargs):
        """Delete every instance whose ``field_name`` is in ``request.data``; echo the ids back."""
        self.model.objects.filter(**{f'{self.field_name}__in': request.data}).delete()
        return SuccessResponse(data=request.data)
|
||
|
|
||
|
|
||
|
class ModelRelationshipAPIView(CustomAPIView):
    """
    Generic CRUD view for the relationships of one instance.

    Each HTTP method runs the optional subclass hooks ``before_<method>``,
    ``handle_<method>`` and ``after_<method>``, then the view loggers, and
    responds with the current relationship data for ``instanceId``.
    """
    model = None
    # Link model, filtered by ``field_name`` = instanceId.
    through_model = None
    # Related model, filtered by ``to_field_name`` against the linked keys.
    relationship_model = None

    # Serializer used to render ``relationship_model`` rows.
    relationship_serializer = None
    # ``through_model`` field holding the owning instance id.
    field_name: str = None
    # Not read by this class.
    from_field_name: str = 'instanceId'
    # ``relationship_model`` field matched against the linked keys.
    to_field_name: str = None
    # ``through_model`` fields read via ``values()``; the first one is the linked key.
    relationship_field_values = ()

    view_logger_classes = [CustomerRelationshipViewLogger, ]

    def get_relationship_data(self, instanceId: str):
        """
        Return the serialized ``relationship_model`` rows linked to ``instanceId``.

        When ``relationship_field_values`` contains ``creator`` and ``ctime``,
        each row also gets ``relationship_creator`` / ``relationship_ctime``
        from the link row with the same key.
        """
        key_field = self.relationship_field_values[0]
        link_rows = list(
            self.through_model.objects.filter(**{self.field_name: instanceId})
            .values(*self.relationship_field_values)
            .distinct()
        )
        business_key_list = [row[key_field] for row in link_rows]

        queryset = self.relationship_model.objects.filter(
            **{f"{self.to_field_name}__in": business_key_list})

        with_link_meta = ('creator' in self.relationship_field_values
                          and 'ctime' in self.relationship_field_values)
        if with_link_meta:
            # Evaluate once: the queryset caches its rows, so the serialization
            # below (presumably iterating this same queryset -- confirm in
            # ModelRelateUtils.model_to_dict) sees them in the same order.
            objects = list(queryset)

        data = ModelRelateUtils.model_to_dict(queryset, self.relationship_serializer, default=[])

        if with_link_meta:
            # Link rows and related rows are not in the same order (nor always
            # the same count), so pair them by key rather than by position.
            # Keys are compared as strings to tolerate int/str field mismatches;
            # the first link row wins when a key is linked more than once.
            link_by_key = {}
            for row in link_rows:
                link_by_key.setdefault(str(row[key_field]), row)
            for obj, ele in zip(objects, data):
                link = link_by_key.get(str(obj.serializable_value(self.to_field_name)))
                if link is not None:
                    ele['relationship_creator'] = link['creator']
                    ele['relationship_ctime'] = link['ctime']
        return data

    def execute_method(self, execute: str, request: Request, instanceId: str, *args, **kwargs):
        """
        Call the hook ``<execute>_<method>`` if the view defines it.

        :param execute: hook stage -- ``'before'``, ``'handle'`` or ``'after'``;
            any other value does nothing.
        """
        method = request.method.lower()
        if execute not in ('before', 'handle', 'after'):
            return
        fun = getattr(self, f'{execute}_{method}', None)
        if fun and isinstance(fun, (FunctionType, MethodType)):
            fun(request, instanceId, *args, **kwargs)

    def do_request(self, request: Request, instanceId: str, *args, **kwargs):
        """Run the before/handle/after hooks and the loggers, then respond with the relationship data."""
        self.execute_method('before', request, instanceId, *args, **kwargs)
        self.execute_method('handle', request, instanceId, *args, **kwargs)
        self.execute_method('after', request, instanceId, *args, **kwargs)
        self.handle_logging(request, instanceId=instanceId, *args, **kwargs)
        data = self.get_relationship_data(instanceId)
        return SuccessResponse(data)

    def get(self, request: Request, instanceId: str, *args, **kwargs):
        return self.do_request(request, instanceId, *args, **kwargs)

    def post(self, request: Request, instanceId: str, *args, **kwargs):
        return self.do_request(request, instanceId, *args, **kwargs)

    def put(self, request: Request, instanceId: str, *args, **kwargs):
        return self.do_request(request, instanceId, *args, **kwargs)

    def delete(self, request: Request, instanceId: str, *args, **kwargs):
        return self.do_request(request, instanceId, *args, **kwargs)
|
||
|
|
||
|
|
||
|
class ModelRelationshipView(ModelRelationshipAPIView):
    """
    Concrete relationship CRUD for ``instanceId``.

    GET lists the related rows, POST adds links, PUT makes the linked keys
    exactly ``request.data``, and DELETE removes the listed links.
    For POST/PUT/DELETE ``request.data`` is treated as a list of related keys.
    """

    def _link(self, request: Request, instanceId: str, keys):
        """
        Create ``through_model`` rows linking ``instanceId`` to ``keys``.

        Only keys matching an existing ``relationship_model`` row (on
        ``to_field_name``) are linked. Keys already linked to ``instanceId``
        are skipped, so repeating a request does not duplicate links.
        """
        key_field = self.relationship_field_values[0]
        targets = self.relationship_model.objects.filter(**{f"{self.to_field_name}__in": keys})
        target_keys = [getattr(target, self.to_field_name) for target in targets]
        linked = set(
            self.through_model.objects
            .filter(**{self.field_name: instanceId, f"{key_field}__in": target_keys})
            .values_list(key_field, flat=True)
        )
        new_links = []
        for key in target_keys:
            if key in linked:
                continue
            linked.add(key)
            new_links.append(self.through_model(**{
                key_field: key,
                self.field_name: instanceId,
                'creator': request.user.username,
            }))
        if new_links:
            self.through_model.objects.bulk_create(new_links)

    def handle_get(self, request: Request, instanceId: str, *args, **kwargs):
        """Return the relationship data of ``instanceId``."""
        data = self.get_relationship_data(instanceId)
        return SuccessResponse(data)

    def handle_post(self, request: Request, instanceId: str, *args, **kwargs):
        """Link ``instanceId`` to the existing related rows listed in ``request.data``."""
        self._link(request, instanceId, request.data)
        data = self.get_relationship_data(instanceId)
        return SuccessResponse(data)

    def handle_put(self, request: Request, instanceId: str, *args, **kwargs):
        """
        Replace the links of ``instanceId`` with the keys in ``request.data``.

        Links whose key is not listed are deleted; listed keys that are not
        linked yet are created. Previously the missing links were computed
        but never created.
        """
        key_field = self.relationship_field_values[0]
        self.through_model.objects.filter(**{self.field_name: instanceId}).exclude(
            **{f"{key_field}__in": request.data}).delete()
        self._link(request, instanceId, request.data)
        data = self.get_relationship_data(instanceId)
        return SuccessResponse(data)

    def handle_delete(self, request: Request, instanceId: str, *args, **kwargs):
        """Remove the links of ``instanceId`` to the keys in ``request.data``."""
        key_field = self.relationship_field_values[0]
        self.through_model.objects.filter(**{
            self.field_name: instanceId,
            f"{key_field}__in": request.data,
        }).delete()
        data = self.get_relationship_data(instanceId)
        return SuccessResponse(data)
|