feat(excel): 添加Excel导入/导出 (#5124)

* refactor(drf_renderer): 添加 ExcelRenderer 支持导出excel文件格式; 优化CSVRenderer, 抽象 BaseRenderer

* perf(renderer): 支持导出资源详情

* refactor(drf_parser): 添加 ExcelParser 支持导入excel文件格式; 优化CSVParser, 抽象 BaseParser

* refactor(drf_parser): 添加 ExcelParser 支持导入excel文件格式; 优化CSVParser, 抽象 BaseParser 2

* perf(renderer): 捕获renderer处理异常

* perf: 添加excel依赖包

* perf(drf): 优化导入导出错误日志

* perf: 添加依赖包 pyexcel-io==0.6.4

* perf: 添加依赖包pyexcel-xlsx==0.6.0

* feat: 修改drf/renderer&parser变量命名

* feat: 修改drf/renderer的bug

* feat: 修改drf/renderer&parser变量命名

Co-authored-by: Bai <bugatti_it@163.com>
pull/5177/head
fit2bot 2020-12-07 15:23:05 +08:00 committed by GitHub
parent 619b521ea1
commit 43b5e97b95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 339 additions and 186 deletions

View File

@ -1 +1,2 @@
from .csv import *
from .csv import *
from .excel import *

View File

@ -0,0 +1,132 @@
import abc
import json
import codecs
from django.utils.translation import ugettext_lazy as _
from rest_framework.parsers import BaseParser
from rest_framework import status
from rest_framework.exceptions import ParseError, APIException
from common.utils import get_logger
logger = get_logger(__file__)
class FileContentOverflowedError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_code = 'file_content_overflowed'
default_detail = _('The file content overflowed (The maximum length `{}` bytes)')
class BaseFileParser(BaseParser):
FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
serializer_cls = None
def check_content_length(self, meta):
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
if content_length > self.FILE_CONTENT_MAX_LENGTH:
msg = FileContentOverflowedError.default_detail.format(self.FILE_CONTENT_MAX_LENGTH)
logger.error(msg)
raise FileContentOverflowedError(msg)
@staticmethod
def get_stream_data(stream):
stream_data = stream.read()
stream_data = stream_data.strip(codecs.BOM_UTF8)
return stream_data
@abc.abstractmethod
def generate_rows(self, stream_data):
raise NotImplemented
def get_column_titles(self, rows):
return next(rows)
def convert_to_field_names(self, column_titles):
fields_map = {}
fields = self.serializer_cls().fields
fields_map.update({v.label: k for k, v in fields.items()})
fields_map.update({k: k for k, _ in fields.items()})
field_names = [
fields_map.get(column_title.strip('*'), '')
for column_title in column_titles
]
return field_names
@staticmethod
def _replace_chinese_quote(s):
trans_table = str.maketrans({
'': '"',
'': '"',
'': '"',
'': '"',
'\'': '"'
})
return s.translate(trans_table)
@classmethod
def process_row(cls, row):
"""
构建json数据前的行处理
"""
new_row = []
for col in row:
# 转换中文引号
col = cls._replace_chinese_quote(col)
# 列表/字典转换
if isinstance(col, str) and (
(col.startswith('[') and col.endswith(']'))
or
(col.startswith("{") and col.endswith("}"))
):
col = json.loads(col)
new_row.append(col)
return new_row
@staticmethod
def process_row_data(row_data):
"""
构建json数据后的行数据处理
"""
new_row_data = {}
for k, v in row_data.items():
if isinstance(v, list) or isinstance(v, dict) or isinstance(v, str) and k.strip() and v.strip():
new_row_data[k] = v
return new_row_data
def generate_data(self, fields_name, rows):
data = []
for row in rows:
# 空行不处理
if not any(row):
continue
row = self.process_row(row)
row_data = dict(zip(fields_name, row))
row_data = self.process_row_data(row_data)
data.append(row_data)
return data
def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {}
try:
view = parser_context['view']
meta = view.request.META
self.serializer_cls = view.get_serializer_class()
except Exception as e:
logger.debug(e, exc_info=True)
raise ParseError('The resource does not support imports!')
self.check_content_length(meta)
try:
stream_data = self.get_stream_data(stream)
rows = self.generate_rows(stream_data)
column_titles = self.get_column_titles(rows)
field_names = self.convert_to_field_names(column_titles)
data = self.generate_data(field_names, rows)
return data
except Exception as e:
logger.error(e, exc_info=True)
raise ParseError('Parse error! ({})'.format(self.media_type))

View File

@ -1,32 +1,13 @@
# ~*~ coding: utf-8 ~*~
#
import json
import chardet
import codecs
import unicodecsv
from django.utils.translation import ugettext as _
from rest_framework.parsers import BaseParser
from rest_framework.exceptions import ParseError, APIException
from rest_framework import status
from common.utils import get_logger
logger = get_logger(__file__)
from .base import BaseFileParser
class CsvDataTooBig(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_code = 'csv_data_too_big'
default_detail = _('The max size of CSV is %d bytes')
class JMSCSVParser(BaseParser):
"""
Parses CSV file to serializer data
"""
CSV_UPLOAD_MAX_SIZE = 1024 * 1024 * 10
class CSVFileParser(BaseFileParser):
media_type = 'text/csv'
@ -38,99 +19,10 @@ class JMSCSVParser(BaseParser):
for line in stream.splitlines():
yield line
@staticmethod
def _gen_rows(csv_data, charset='utf-8', **kwargs):
csv_reader = unicodecsv.reader(csv_data, encoding=charset, **kwargs)
def generate_rows(self, stream_data):
detect_result = chardet.detect(stream_data)
encoding = detect_result.get("encoding", "utf-8")
lines = self._universal_newlines(stream_data)
csv_reader = unicodecsv.reader(lines, encoding=encoding)
for row in csv_reader:
if not any(row): # 空行
continue
yield row
@staticmethod
def _get_fields_map(serializer_cls):
fields_map = {}
fields = serializer_cls().fields
fields_map.update({v.label: k for k, v in fields.items()})
fields_map.update({k: k for k, _ in fields.items()})
return fields_map
@staticmethod
def _replace_chinese_quot(str_):
trans_table = str.maketrans({
'': '"',
'': '"',
'': '"',
'': '"',
'\'': '"'
})
return str_.translate(trans_table)
@classmethod
def _process_row(cls, row):
"""
构建json数据前的行处理
"""
_row = []
for col in row:
# 列表转换
if isinstance(col, str) and col.startswith('[') and col.endswith(']'):
col = cls._replace_chinese_quot(col)
col = json.loads(col)
# 字典转换
if isinstance(col, str) and col.startswith("{") and col.endswith("}"):
col = cls._replace_chinese_quot(col)
col = json.loads(col)
_row.append(col)
return _row
@staticmethod
def _process_row_data(row_data):
"""
构建json数据后的行数据处理
"""
_row_data = {}
for k, v in row_data.items():
if isinstance(v, list) or isinstance(v, dict)\
or isinstance(v, str) and k.strip() and v.strip():
_row_data[k] = v
return _row_data
def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {}
try:
view = parser_context['view']
meta = view.request.META
serializer_cls = view.get_serializer_class()
except Exception as e:
logger.debug(e, exc_info=True)
raise ParseError('The resource does not support imports!')
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
if content_length > self.CSV_UPLOAD_MAX_SIZE:
msg = CsvDataTooBig.default_detail % self.CSV_UPLOAD_MAX_SIZE
logger.error(msg)
raise CsvDataTooBig(msg)
try:
stream_data = stream.read()
stream_data = stream_data.strip(codecs.BOM_UTF8)
detect_result = chardet.detect(stream_data)
encoding = detect_result.get("encoding", "utf-8")
binary = self._universal_newlines(stream_data)
rows = self._gen_rows(binary, charset=encoding)
header = next(rows)
fields_map = self._get_fields_map(serializer_cls)
header = [fields_map.get(name.strip('*'), '') for name in header]
data = []
for row in rows:
row = self._process_row(row)
row_data = dict(zip(header, row))
row_data = self._process_row_data(row_data)
data.append(row_data)
return data
except Exception as e:
logger.error(e, exc_info=True)
raise ParseError('CSV parse error!')

View File

@ -0,0 +1,14 @@
import pyexcel
from .base import BaseFileParser
class ExcelFileParser(BaseFileParser):
media_type = 'text/xlsx'
def generate_rows(self, stream_data):
workbook = pyexcel.get_book(file_type='xlsx', file_content=stream_data)
# 默认获取第一个工作表sheet
sheet = workbook.sheet_by_index(0)
rows = sheet.rows()
return rows

View File

@ -1,6 +1,7 @@
from rest_framework import renderers
from .csv import *
from .excel import *
class PassthroughRenderer(renderers.BaseRenderer):

View File

@ -0,0 +1,132 @@
import abc
from datetime import datetime
from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json
from common.utils import get_logger
logger = get_logger(__file__)
class BaseFileRenderer(BaseRenderer):
# 渲染模版标识, 导入、导出、更新模版: ['import', 'update', 'export']
template = 'export'
serializer = None
@staticmethod
def _check_validation_data(data):
detail_key = "detail"
if detail_key in data:
return False
return True
@staticmethod
def _json_format_response(response_data):
return json.dumps(response_data)
def set_response_disposition(self, response):
serializer = self.serializer
if response and hasattr(serializer, 'Meta') and hasattr(serializer.Meta, "model"):
model_name = serializer.Meta.model.__name__.lower()
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename = "{}_{}.{}".format(model_name, now, self.format)
disposition = 'attachment; filename="{}"'.format(filename)
response['Content-Disposition'] = disposition
def get_rendered_fields(self):
fields = self.serializer.fields
if self.template == 'import':
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
elif self.template == 'update':
return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
else:
return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
@staticmethod
def get_column_titles(render_fields):
return [
'*{}'.format(field.label) if field.required else str(field.label)
for field in render_fields
]
def process_data(self, data):
results = data['results'] if 'results' in data else data
if isinstance(results, dict):
results = [results]
if self.template == 'import':
results = [results[0]] if results else results
else:
# 限制数据数量
results = results[:10000]
# 会将一些 UUID 字段转化为 string
results = json.loads(json.dumps(results, cls=encoders.JSONEncoder))
return results
@staticmethod
def generate_rows(data, render_fields):
for item in data:
row = []
for field in render_fields:
value = item.get(field.field_name)
value = str(value) if value else ''
row.append(value)
yield row
@abc.abstractmethod
def initial_writer(self):
raise NotImplementedError
def write_column_titles(self, column_titles):
self.write_row(column_titles)
def write_rows(self, rows):
for row in rows:
self.write_row(row)
@abc.abstractmethod
def write_row(self, row):
raise NotImplementedError
@abc.abstractmethod
def get_rendered_value(self):
raise NotImplementedError
def render(self, data, accepted_media_type=None, renderer_context=None):
if data is None:
return bytes()
if not self._check_validation_data(data):
return self._json_format_response(data)
try:
renderer_context = renderer_context or {}
request = renderer_context['request']
response = renderer_context['response']
view = renderer_context['view']
self.template = request.query_params.get('template', 'export')
self.serializer = view.get_serializer()
self.set_response_disposition(response)
except Exception as e:
logger.debug(e, exc_info=True)
value = 'The resource not support export!'.encode('utf-8')
return value
try:
rendered_fields = self.get_rendered_fields()
column_titles = self.get_column_titles(rendered_fields)
data = self.process_data(data)
rows = self.generate_rows(data, rendered_fields)
self.initial_writer()
self.write_column_titles(column_titles)
self.write_rows(rows)
value = self.get_rendered_value()
except Exception as e:
logger.debug(e, exc_info=True)
value = 'Render error! ({})'.format(self.media_type).encode('utf-8')
return value
return value

View File

@ -1,83 +1,30 @@
# ~*~ coding: utf-8 ~*~
#
import unicodecsv
import codecs
from datetime import datetime
import unicodecsv
from six import BytesIO
from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json
from common.utils import get_logger
logger = get_logger(__file__)
from .base import BaseFileRenderer
class JMSCSVRender(BaseRenderer):
class CSVFileRenderer(BaseFileRenderer):
media_type = 'text/csv'
format = 'csv'
@staticmethod
def _get_show_fields(fields, template):
if template == 'import':
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
elif template == 'update':
return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
else:
return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
writer = None
buffer = None
@staticmethod
def _gen_table(data, fields):
data = data[:10000]
yield ['*{}'.format(f.label) if f.required else f.label for f in fields]
def initial_writer(self):
csv_buffer = BytesIO()
csv_buffer.write(codecs.BOM_UTF8)
csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
self.buffer = csv_buffer
self.writer = csv_writer
for item in data:
row = [item.get(f.field_name) for f in fields]
yield row
def set_response_disposition(self, serializer, context):
response = context.get('response')
if response and hasattr(serializer, 'Meta') and \
hasattr(serializer.Meta, "model"):
model_name = serializer.Meta.model.__name__.lower()
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename = "{}_{}.csv".format(model_name, now)
disposition = 'attachment; filename="{}"'.format(filename)
response['Content-Disposition'] = disposition
def render(self, data, media_type=None, renderer_context=None):
renderer_context = renderer_context or {}
request = renderer_context['request']
template = request.query_params.get('template', 'export')
view = renderer_context['view']
if isinstance(data, dict):
data = data.get("results", [])
if template == 'import':
data = [data[0]] if data else data
data = json.loads(json.dumps(data, cls=encoders.JSONEncoder))
try:
serializer = view.get_serializer()
self.set_response_disposition(serializer, renderer_context)
except Exception as e:
logger.debug(e, exc_info=True)
value = 'The resource not support export!'.encode('utf-8')
else:
fields = serializer.fields
show_fields = self._get_show_fields(fields, template)
table = self._gen_table(data, show_fields)
csv_buffer = BytesIO()
csv_buffer.write(codecs.BOM_UTF8)
csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
for row in table:
csv_writer.writerow(row)
value = csv_buffer.getvalue()
def write_row(self, row):
self.writer.writerow(row)
def get_rendered_value(self):
value = self.buffer.getvalue()
return value

View File

@ -0,0 +1,28 @@
from openpyxl import Workbook
from openpyxl.writer.excel import save_virtual_workbook
from .base import BaseFileRenderer
class ExcelFileRenderer(BaseFileRenderer):
media_type = "application/xlsx"
format = "xlsx"
wb = None
ws = None
row_count = 0
def initial_writer(self):
self.wb = Workbook()
self.ws = self.wb.active
def write_row(self, row):
self.row_count += 1
column_count = 0
for cell_value in row:
column_count += 1
self.ws.cell(row=self.row_count, column=column_count, value=cell_value)
def get_rendered_value(self):
value = save_virtual_workbook(self.wb)
return value

View File

@ -12,13 +12,16 @@ REST_FRAMEWORK = {
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
# 'rest_framework.renderers.BrowsableAPIRenderer',
'common.drf.renders.JMSCSVRender',
'common.drf.renders.CSVFileRenderer',
'common.drf.renders.ExcelFileRenderer',
),
'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
'rest_framework.parsers.FormParser',
'rest_framework.parsers.MultiPartParser',
'common.drf.parsers.JMSCSVParser',
'common.drf.parsers.CSVFileParser',
'common.drf.parsers.ExcelFileParser',
'rest_framework.parsers.FileUploadParser',
),
'DEFAULT_AUTHENTICATION_CLASSES': (

View File

@ -105,3 +105,6 @@ azure-mgmt-compute==4.6.2
azure-mgmt-network==2.7.0
msrestazure==0.6.4
adal==1.2.5
openpyxl==3.0.5
pyexcel==0.6.6
pyexcel-xlsx==0.6.0