From 3a02832e08b6aa2142ed9c2b42bf1ef0204214d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8C=BF=E5=B0=8F=E5=A4=A9?= <1638245306@qq.com>
Date: Thu, 12 Jan 2023 22:02:38 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DBUG:=20=E6=96=B0=E5=A2=9E?=
 =?UTF-8?q?=E5=AF=BC=E5=85=A5=E5=92=8C=E5=AF=BC=E5=85=A5=E6=9B=B4=E6=96=B0?=
 =?UTF-8?q?bug=E4=BC=98=E5=8C=96?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 backend/dvadmin/utils/import_export.py       | 38 ++++++++++----
 backend/dvadmin/utils/import_export_mixin.py | 52 +++++++++++++++-----
 2 files changed, 68 insertions(+), 22 deletions(-)

diff --git a/backend/dvadmin/utils/import_export.py b/backend/dvadmin/utils/import_export.py
index 49534f5..2bd6e1e 100644
--- a/backend/dvadmin/utils/import_export.py
+++ b/backend/dvadmin/utils/import_export.py
@@ -1,10 +1,13 @@
 # -*- coding: utf-8 -*-
 import os
 import re
+from datetime import datetime
 
 import openpyxl
 from django.conf import settings
 
+from dvadmin.utils.validator import CustomValidationError
+
 
 def import_to_data(file_url, field_data, m2m_fields=None):
     """
@@ -18,9 +21,9 @@ def import_to_data(file_url, field_data, m2m_fields=None):
     file_path_dir = os.path.join(settings.BASE_DIR, file_url)
     workbook = openpyxl.load_workbook(file_path_dir)
     table = workbook[workbook.sheetnames[0]]
-    theader = tuple(table.values)[0]  # Excel的表头
-    is_update = '更新主键(勿改)' in theader  # 是否导入更新
-    if is_update is False:  # 不是更新时,删除id列
+    theader = tuple(table.values)[0] #Excel的表头
+    is_update = '更新主键(勿改)' in theader #是否导入更新
+    if is_update is False: #不是更新时,删除id列
         field_data.pop('id')
     # 获取参数映射
     validation_data_dict = {}
@@ -44,15 +47,30 @@ def import_to_data(file_url, field_data, m2m_fields=None):
         if i == 0:
             continue
         array = {}
-        for index, key in enumerate(field_data.keys()):
+        for index, item in enumerate(field_data.items()):
+            items = list(item)
+            key = items[0]
+            values = items[1]
+            value_type = 'str'
+            if isinstance(values, dict):
+                value_type = values.get('type','str')
             cell_value = table.cell(row=row + 1, column=index + 2).value
-            # 由于excel导入数字类型后,会出现数字加 .0 的,进行处理
-            if type(cell_value) is float and str(cell_value).split(".")[1] == "0":
-                cell_value = int(str(cell_value).split(".")[0])
-            if type(cell_value) is str:
-                cell_value = cell_value.strip(" \t\n\r")
-            if cell_value is None:
+            if cell_value is None or cell_value=='':
                 continue
+            elif value_type == 'date':
+                print(61, datetime.strptime(str(cell_value), '%Y-%m-%d %H:%M:%S').date())
+                try:
+                    cell_value = datetime.strptime(str(cell_value), '%Y-%m-%d %H:%M:%S').date()
+                except:
+                    raise CustomValidationError('日期格式不正确')
+            elif value_type == 'datetime':
+                cell_value = datetime.strptime(str(cell_value), '%Y-%m-%d %H:%M:%S')
+            else:
+            # 由于excel导入数字类型后,会出现数字加 .0 的,进行处理
+                if type(cell_value) is float and str(cell_value).split(".")[1] == "0":
+                    cell_value = int(str(cell_value).split(".")[0])
+                elif type(cell_value) is str:
+                    cell_value = cell_value.strip(" \t\n\r")
             if key in validation_data_dict:
                 array[key] = validation_data_dict.get(key, {}).get(cell_value, None)
                 if key in m2m_fields:
diff --git a/backend/dvadmin/utils/import_export_mixin.py b/backend/dvadmin/utils/import_export_mixin.py
index 40d2db5..44f51cc 100644
--- a/backend/dvadmin/utils/import_export_mixin.py
+++ b/backend/dvadmin/utils/import_export_mixin.py
@@ -57,11 +57,6 @@ class ImportSerializerMixin:
             length += 2.1 if ord(char) > 256 else 1
         return round(length, 1) if length <= self.export_column_width else self.export_column_width
 
-    @action(methods=['get'],detail=False)
-    def update_field(self,request:Request):
-        data = [{"label":value,"value":key} for key,value in self.import_field_dict.items()]
-        return DetailResponse(data=data)
-
     @action(methods=['get','post'],detail=False)
     @transaction.atomic  # Django 事务,防止出错
     def import_data(self, request: Request, *args, **kwargs):
@@ -153,7 +148,7 @@ class ImportSerializerMixin:
             for ele in data:
                 filter_dic = {'id':ele.get('id')}
                 instance = filter_dic and queryset.filter(**filter_dic).first()
-                print(instance)
+                # print(156,ele)
                 serializer = self.import_serializer_class(instance, data=ele, request=request)
                 serializer.is_valid(raise_exception=True)
                 serializer.save()
@@ -170,15 +165,48 @@ class ImportSerializerMixin:
         response["Access-Control-Expose-Headers"] = f"Content-Disposition"
         response["content-disposition"] = f'attachment;filename={quote(str(f"导出{get_verbose_name(queryset)}.xlsx"))}'
         wb = Workbook()
+        ws1 = wb.create_sheet("data", 1)
+        ws1.sheet_state = "hidden"
         ws = wb.active
         import_field_dict = {}
-        for key,val in self.import_field_dict.items():
-            if isinstance(val,dict):
-                import_field_dict[key] = val.get('title')
+        header_data = ["序号","更新主键(勿改)"]
+        hidden_header = ["#","id"]
+        #----设置选项----
+        validation_data_dict = {}
+        for index, item in enumerate(self.import_field_dict.items()):
+            items = list(item)
+            key = items[0]
+            value = items[1]
+            if isinstance(value, dict):
+                header_data.append(value.get("title"))
+                hidden_header.append(value.get('display'))
+                choices = value.get("choices", {})
+                if choices.get("data"):
+                    data_list = []
+                    data_list.extend(choices.get("data").keys())
+                    validation_data_dict[value.get("title")] = data_list
+                elif choices.get("queryset") and choices.get("values_name"):
+                    data_list = choices.get("queryset").values_list(choices.get("values_name"), flat=True)
+                    validation_data_dict[value.get("title")] = list(data_list)
+                else:
+                    continue
+                column_letter = get_column_letter(len(validation_data_dict))
+                dv = DataValidation(
+                    type="list",
+                    formula1=f"{quote_sheetname('data')}!${column_letter}$2:${column_letter}${len(validation_data_dict[value.get('title')]) + 1}",
+                    allow_blank=True,
+                )
+                ws.add_data_validation(dv)
+                dv.add(f"{get_column_letter(index + 3)}2:{get_column_letter(index + 3)}1048576")
             else:
-                import_field_dict[key] = val
-        header_data = ["序号","更新主键(勿改)", *import_field_dict.values()]
-        hidden_header = ["#","id", *import_field_dict.keys()]
+                header_data.append(value)
+                hidden_header.append(key)
+        # 添加数据列
+        ws1.append(list(validation_data_dict.keys()))
+        for index, validation_data in enumerate(validation_data_dict.values()):
+            for inx, ele in enumerate(validation_data):
+                ws1[f"{get_column_letter(index + 1)}{inx + 2}"] = ele
+        #--------
         df_len_max = [self.get_string_len(ele) for ele in header_data]
         row = get_column_letter(len(hidden_header) + 1)
         column = 1