From 43b5e97b95cb71fbb1370fb4cc175f140160f080 Mon Sep 17 00:00:00 2001
From: fit2bot <68588906+fit2bot@users.noreply.github.com>
Date: Mon, 7 Dec 2020 15:23:05 +0800
Subject: [PATCH] feat(excel): Add Excel import/export (#5124)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refactor(drf_renderer): add ExcelRenderer to support exporting the excel file format; optimize CSVRenderer and extract an abstract BaseRenderer
* perf(renderer): support exporting resource details
* refactor(drf_parser): add ExcelParser to support importing the excel file format; optimize CSVParser and extract an abstract BaseParser
* refactor(drf_parser): add ExcelParser to support importing the excel file format; optimize CSVParser and extract an abstract BaseParser (2)
* perf(renderer): catch exceptions raised during renderer processing
* perf: add the excel dependency packages
* perf(drf): improve import/export error logging
* perf: add dependency pyexcel-io==0.6.4
* perf: add dependency pyexcel-xlsx==0.6.0
* feat: rename variables in drf renderer & parser
* feat: fix a bug in the drf renderer
* feat: rename variables in drf renderer & parser

Co-authored-by: Bai
---
 apps/common/drf/parsers/__init__.py |   3 +-
 apps/common/drf/parsers/base.py     | 132 ++++++++++++++++++++++++++++
 apps/common/drf/parsers/csv.py      | 122 ++-----------------------
 apps/common/drf/parsers/excel.py    |  14 +++
 apps/common/drf/renders/__init__.py |   1 +
 apps/common/drf/renders/base.py     | 132 ++++++++++++++++++++++++++++
 apps/common/drf/renders/csv.py      |  83 ++++------------
 apps/common/drf/renders/excel.py    |  28 ++++++
 apps/jumpserver/settings/libs.py    |   7 +-
 requirements/requirements.txt       |   3 +
 10 files changed, 339 insertions(+), 186 deletions(-)
 create mode 100644 apps/common/drf/parsers/base.py
 create mode 100644 apps/common/drf/parsers/excel.py
 create mode 100644 apps/common/drf/renders/base.py
 create mode 100644 apps/common/drf/renders/excel.py

diff --git a/apps/common/drf/parsers/__init__.py b/apps/common/drf/parsers/__init__.py
index 671c86586..75dc28249 100644
--- a/apps/common/drf/parsers/__init__.py
+++ b/apps/common/drf/parsers/__init__.py
@@ -1 +1,2 @@
-from .csv import *
\ No newline at end of file
+from .csv import *
+from .excel import *
\ No newline at end of file
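
Both new pipelines plug into DRF's standard content negotiation, so a client drives import and export purely through the Content-Type header, the Accept header or the ?format= query parameter, plus the ?template= switch the renderers understand. A minimal client-side sketch of that flow; the endpoint URL and token below are hypothetical, only the headers and query parameters come from this patch:

    import requests

    url = 'https://jms.example.com/api/v1/assets/assets/'   # hypothetical endpoint
    headers = {'Authorization': 'Token <token>'}             # hypothetical auth

    # Export: ?format=xlsx picks ExcelFileRenderer, template=export keeps readable columns
    resp = requests.get(url, params={'format': 'xlsx', 'template': 'export'}, headers=headers)
    with open('assets.xlsx', 'wb') as f:
        f.write(resp.content)

    # Import: Content-Type text/xlsx routes the uploaded bytes to ExcelFileParser
    with open('assets.xlsx', 'rb') as f:
        requests.post(url, data=f.read(),
                      headers={**headers, 'Content-Type': 'text/xlsx'})
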
diff --git a/apps/common/drf/parsers/base.py b/apps/common/drf/parsers/base.py
new file mode 100644
index 000000000..605dcdd08
--- /dev/null
+++ b/apps/common/drf/parsers/base.py
@@ -0,0 +1,132 @@
+import abc
+import json
+import codecs
+from django.utils.translation import ugettext_lazy as _
+from rest_framework.parsers import BaseParser
+from rest_framework import status
+from rest_framework.exceptions import ParseError, APIException
+from common.utils import get_logger
+
+logger = get_logger(__file__)
+
+
+class FileContentOverflowedError(APIException):
+    status_code = status.HTTP_400_BAD_REQUEST
+    default_code = 'file_content_overflowed'
+    default_detail = _('The file content overflowed (The maximum length `{}` bytes)')
+
+
+class BaseFileParser(BaseParser):
+
+    FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
+
+    serializer_cls = None
+
+    def check_content_length(self, meta):
+        content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
+        if content_length > self.FILE_CONTENT_MAX_LENGTH:
+            msg = FileContentOverflowedError.default_detail.format(self.FILE_CONTENT_MAX_LENGTH)
+            logger.error(msg)
+            raise FileContentOverflowedError(msg)
+
+    @staticmethod
+    def get_stream_data(stream):
+        stream_data = stream.read()
+        stream_data = stream_data.strip(codecs.BOM_UTF8)
+        return stream_data
+
+    @abc.abstractmethod
+    def generate_rows(self, stream_data):
+        raise NotImplementedError
+
+    def get_column_titles(self, rows):
+        return next(rows)
+
+    def convert_to_field_names(self, column_titles):
+        fields_map = {}
+        fields = self.serializer_cls().fields
+        fields_map.update({v.label: k for k, v in fields.items()})
+        fields_map.update({k: k for k, _ in fields.items()})
+        field_names = [
+            fields_map.get(column_title.strip('*'), '')
+            for column_title in column_titles
+        ]
+        return field_names
+
+    @staticmethod
+    def _replace_chinese_quote(s):
+        trans_table = str.maketrans({
+            '“': '"',
+            '”': '"',
+            '‘': '"',
+            '’': '"',
+            '\'': '"'
+        })
+        return s.translate(trans_table)
+
+    @classmethod
+    def process_row(cls, row):
+        """
+        Row processing before the JSON data is built
+        """
+        new_row = []
+        for col in row:
+            # Normalize Chinese quotation marks
+            col = cls._replace_chinese_quote(col)
+            # Convert list/dict cells
+            if isinstance(col, str) and (
+                    (col.startswith('[') and col.endswith(']'))
+                    or
+                    (col.startswith("{") and col.endswith("}"))
+            ):
+                col = json.loads(col)
+            new_row.append(col)
+        return new_row
+
+    @staticmethod
+    def process_row_data(row_data):
+        """
+        Row data processing after the JSON data is built
+        """
+        new_row_data = {}
+        for k, v in row_data.items():
+            if isinstance(v, list) or isinstance(v, dict) or isinstance(v, str) and k.strip() and v.strip():
+                new_row_data[k] = v
+        return new_row_data
+
+    def generate_data(self, fields_name, rows):
+        data = []
+        for row in rows:
+            # Skip empty rows
+            if not any(row):
+                continue
+            row = self.process_row(row)
+            row_data = dict(zip(fields_name, row))
+            row_data = self.process_row_data(row_data)
+            data.append(row_data)
+        return data
+
+    def parse(self, stream, media_type=None, parser_context=None):
+        parser_context = parser_context or {}
+
+        try:
+            view = parser_context['view']
+            meta = view.request.META
+            self.serializer_cls = view.get_serializer_class()
+        except Exception as e:
+            logger.debug(e, exc_info=True)
+            raise ParseError('The resource does not support imports!')
+
+        self.check_content_length(meta)
+
+        try:
+            stream_data = self.get_stream_data(stream)
+            rows = self.generate_rows(stream_data)
+            column_titles = self.get_column_titles(rows)
+            field_names = self.convert_to_field_names(column_titles)
+            data = self.generate_data(field_names, rows)
+            return data
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise ParseError('Parse error! ({})'.format(self.media_type))
+
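
BaseFileParser now owns everything except row extraction: the content-length guard, UTF-8 BOM stripping, the header-to-field-name mapping against the view's serializer, JSON coercion of list/dict cells, and the final list of row dicts handed back to DRF. A concrete parser only implements generate_rows(). As an illustration only (this class is not part of the patch), a tab-separated parser would need just a few lines:

    from common.drf.parsers.base import BaseFileParser


    class TSVFileParser(BaseFileParser):
        # hypothetical subclass, shown only to illustrate the single extension point
        media_type = 'text/tab-separated-values'

        def generate_rows(self, stream_data):
            # stream_data is the raw upload as bytes, with any UTF-8 BOM already stripped
            for line in stream_data.decode('utf-8').splitlines():
                yield line.split('\t')
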
diff --git a/apps/common/drf/parsers/csv.py b/apps/common/drf/parsers/csv.py
index de0d14ea7..0dd11aa4b 100644
--- a/apps/common/drf/parsers/csv.py
+++ b/apps/common/drf/parsers/csv.py
@@ -1,32 +1,13 @@
 # ~*~ coding: utf-8 ~*~
 #
-import json
 import chardet
-import codecs
 import unicodecsv
 
-from django.utils.translation import ugettext as _
-from rest_framework.parsers import BaseParser
-from rest_framework.exceptions import ParseError, APIException
-from rest_framework import status
-
-from common.utils import get_logger
-
-logger = get_logger(__file__)
+from .base import BaseFileParser
 
 
-class CsvDataTooBig(APIException):
-    status_code = status.HTTP_400_BAD_REQUEST
-    default_code = 'csv_data_too_big'
-    default_detail = _('The max size of CSV is %d bytes')
-
-
-class JMSCSVParser(BaseParser):
-    """
-    Parses CSV file to serializer data
-    """
-    CSV_UPLOAD_MAX_SIZE = 1024 * 1024 * 10
+class CSVFileParser(BaseFileParser):
 
     media_type = 'text/csv'
 
@@ -38,99 +19,10 @@ class JMSCSVParser(BaseParser):
         for line in stream.splitlines():
             yield line
 
-    @staticmethod
-    def _gen_rows(csv_data, charset='utf-8', **kwargs):
-        csv_reader = unicodecsv.reader(csv_data, encoding=charset, **kwargs)
+    def generate_rows(self, stream_data):
+        detect_result = chardet.detect(stream_data)
+        encoding = detect_result.get("encoding", "utf-8")
+        lines = self._universal_newlines(stream_data)
+        csv_reader = unicodecsv.reader(lines, encoding=encoding)
         for row in csv_reader:
-            if not any(row):  # 空行
-                continue
             yield row
-
-    @staticmethod
-    def _get_fields_map(serializer_cls):
-        fields_map = {}
-        fields = serializer_cls().fields
-        fields_map.update({v.label: k for k, v in fields.items()})
-        fields_map.update({k: k for k, _ in fields.items()})
-        return fields_map
-
-    @staticmethod
-    def _replace_chinese_quot(str_):
-        trans_table = str.maketrans({
-            '“': '"',
-            '”': '"',
-            '‘': '"',
-            '’': '"',
-            '\'': '"'
-        })
-        return str_.translate(trans_table)
-
-    @classmethod
-    def _process_row(cls, row):
-        """
-        构建json数据前的行处理
-        """
-        _row = []
-
-        for col in row:
-            # 列表转换
-            if isinstance(col, str) and col.startswith('[') and col.endswith(']'):
-                col = cls._replace_chinese_quot(col)
-                col = json.loads(col)
-            # 字典转换
-            if isinstance(col, str) and col.startswith("{") and col.endswith("}"):
-                col = cls._replace_chinese_quot(col)
-                col = json.loads(col)
-            _row.append(col)
-        return _row
-
-    @staticmethod
-    def _process_row_data(row_data):
-        """
-        构建json数据后的行数据处理
-        """
-        _row_data = {}
-        for k, v in row_data.items():
-            if isinstance(v, list) or isinstance(v, dict)\
-                    or isinstance(v, str) and k.strip() and v.strip():
-                _row_data[k] = v
-        return _row_data
-
-    def parse(self, stream, media_type=None, parser_context=None):
-        parser_context = parser_context or {}
-        try:
-            view = parser_context['view']
-            meta = view.request.META
-            serializer_cls = view.get_serializer_class()
-        except Exception as e:
-            logger.debug(e, exc_info=True)
-            raise ParseError('The resource does not support imports!')
-
-        content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
-        if content_length > self.CSV_UPLOAD_MAX_SIZE:
-            msg = CsvDataTooBig.default_detail % self.CSV_UPLOAD_MAX_SIZE
-            logger.error(msg)
-            raise CsvDataTooBig(msg)
-
-        try:
-            stream_data = stream.read()
-            stream_data = stream_data.strip(codecs.BOM_UTF8)
-            detect_result = chardet.detect(stream_data)
-            encoding = detect_result.get("encoding", "utf-8")
-            binary = self._universal_newlines(stream_data)
-            rows = self._gen_rows(binary, charset=encoding)
-
-            header = next(rows)
-            fields_map = self._get_fields_map(serializer_cls)
-            header = [fields_map.get(name.strip('*'), '') for name in header]
-
-            data = []
-            for row in rows:
-                row = self._process_row(row)
-                row_data = dict(zip(header, row))
-                row_data = self._process_row_data(row_data)
-                data.append(row_data)
-            return data
-        except Exception as e:
-            logger.error(e, exc_info=True)
-            raise ParseError('CSV parse error!')
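
With the shared logic gone, CSVFileParser.generate_rows() is reduced to encoding detection plus row splitting; the cell-level clean-up still happens in BaseFileParser. A standalone sketch of that hand-off, using fabricated upload bytes rather than data from the patch:

    import chardet
    import unicodecsv

    raw = '*主机名称,IP地址\nweb-01,10.0.0.1\n'.encode('utf-8')   # fabricated upload bytes
    encoding = chardet.detect(raw).get('encoding', 'utf-8')       # typically 'utf-8' here
    rows = unicodecsv.reader(raw.splitlines(), encoding=encoding)
    print(list(rows))   # [['*主机名称', 'IP地址'], ['web-01', '10.0.0.1']]
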
diff --git a/apps/common/drf/parsers/excel.py b/apps/common/drf/parsers/excel.py
new file mode 100644
index 000000000..c5007866c
--- /dev/null
+++ b/apps/common/drf/parsers/excel.py
@@ -0,0 +1,14 @@
+import pyexcel
+from .base import BaseFileParser
+
+
+class ExcelFileParser(BaseFileParser):
+
+    media_type = 'text/xlsx'
+
+    def generate_rows(self, stream_data):
+        workbook = pyexcel.get_book(file_type='xlsx', file_content=stream_data)
+        # Use the first worksheet by default
+        sheet = workbook.sheet_by_index(0)
+        rows = sheet.rows()
+        return rows
diff --git a/apps/common/drf/renders/__init__.py b/apps/common/drf/renders/__init__.py
index f99b13586..bbefe8783 100644
--- a/apps/common/drf/renders/__init__.py
+++ b/apps/common/drf/renders/__init__.py
@@ -1,6 +1,7 @@
 from rest_framework import renderers
 
 from .csv import *
+from .excel import *
 
 
 class PassthroughRenderer(renderers.BaseRenderer):
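
pyexcel only needs the raw bytes plus a file_type hint, so ExcelFileParser never touches the filesystem and only the first worksheet is imported. A self-contained sketch of the same call chain, building throwaway workbook bytes with openpyxl first (both packages are added to requirements further down); the cell values are made up:

    import io
    import pyexcel
    from openpyxl import Workbook

    # build throwaway xlsx bytes in memory
    wb = Workbook()
    wb.active.append(['*Hostname', 'IP'])
    wb.active.append(['web-01', '10.0.0.1'])
    buf = io.BytesIO()
    wb.save(buf)

    # the same calls ExcelFileParser.generate_rows() makes
    book = pyexcel.get_book(file_type='xlsx', file_content=buf.getvalue())
    sheet = book.sheet_by_index(0)      # first worksheet only
    print(list(sheet.rows()))           # [['*Hostname', 'IP'], ['web-01', '10.0.0.1']]
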
diff --git a/apps/common/drf/renders/base.py b/apps/common/drf/renders/base.py
new file mode 100644
index 000000000..deac735cc
--- /dev/null
+++ b/apps/common/drf/renders/base.py
@@ -0,0 +1,132 @@
+import abc
+from datetime import datetime
+from rest_framework.renderers import BaseRenderer
+from rest_framework.utils import encoders, json
+
+from common.utils import get_logger
+
+logger = get_logger(__file__)
+
+
+class BaseFileRenderer(BaseRenderer):
+    # Template type used when rendering; one of ['import', 'update', 'export']
+    template = 'export'
+    serializer = None
+
+    @staticmethod
+    def _check_validation_data(data):
+        detail_key = "detail"
+        if detail_key in data:
+            return False
+        return True
+
+    @staticmethod
+    def _json_format_response(response_data):
+        return json.dumps(response_data)
+
+    def set_response_disposition(self, response):
+        serializer = self.serializer
+        if response and hasattr(serializer, 'Meta') and hasattr(serializer.Meta, "model"):
+            model_name = serializer.Meta.model.__name__.lower()
+            now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+            filename = "{}_{}.{}".format(model_name, now, self.format)
+            disposition = 'attachment; filename="{}"'.format(filename)
+            response['Content-Disposition'] = disposition
+
+    def get_rendered_fields(self):
+        fields = self.serializer.fields
+        if self.template == 'import':
+            return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
+        elif self.template == 'update':
+            return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
+        else:
+            return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
+
+    @staticmethod
+    def get_column_titles(render_fields):
+        return [
+            '*{}'.format(field.label) if field.required else str(field.label)
+            for field in render_fields
+        ]
+
+    def process_data(self, data):
+        results = data['results'] if 'results' in data else data
+
+        if isinstance(results, dict):
+            results = [results]
+
+        if self.template == 'import':
+            results = [results[0]] if results else results
+
+        else:
+            # Limit the amount of data
+            results = results[:10000]
+            # Converts fields such as UUIDs to strings
+            results = json.loads(json.dumps(results, cls=encoders.JSONEncoder))
+        return results
+
+    @staticmethod
+    def generate_rows(data, render_fields):
+        for item in data:
+            row = []
+            for field in render_fields:
+                value = item.get(field.field_name)
+                value = str(value) if value else ''
+                row.append(value)
+            yield row
+
+    @abc.abstractmethod
+    def initial_writer(self):
+        raise NotImplementedError
+
+    def write_column_titles(self, column_titles):
+        self.write_row(column_titles)
+
+    def write_rows(self, rows):
+        for row in rows:
+            self.write_row(row)
+
+    @abc.abstractmethod
+    def write_row(self, row):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def get_rendered_value(self):
+        raise NotImplementedError
+
+    def render(self, data, accepted_media_type=None, renderer_context=None):
+        if data is None:
+            return bytes()
+
+        if not self._check_validation_data(data):
+            return self._json_format_response(data)
+
+        try:
+            renderer_context = renderer_context or {}
+            request = renderer_context['request']
+            response = renderer_context['response']
+            view = renderer_context['view']
+            self.template = request.query_params.get('template', 'export')
+            self.serializer = view.get_serializer()
+            self.set_response_disposition(response)
+        except Exception as e:
+            logger.debug(e, exc_info=True)
+            value = 'The resource not support export!'.encode('utf-8')
+            return value
+
+        try:
+            rendered_fields = self.get_rendered_fields()
+            column_titles = self.get_column_titles(rendered_fields)
+            data = self.process_data(data)
+            rows = self.generate_rows(data, rendered_fields)
+            self.initial_writer()
+            self.write_column_titles(column_titles)
+            self.write_rows(rows)
+            value = self.get_rendered_value()
+        except Exception as e:
+            logger.debug(e, exc_info=True)
+            value = 'Render error! ({})'.format(self.media_type).encode('utf-8')
+            return value
+
+        return value
+
diff --git a/apps/common/drf/renders/csv.py b/apps/common/drf/renders/csv.py
index 435e3d4a6..ba469a21f 100644
--- a/apps/common/drf/renders/csv.py
+++ b/apps/common/drf/renders/csv.py
@@ -1,83 +1,30 @@
 # ~*~ coding: utf-8 ~*~
 #
-import unicodecsv
 import codecs
-from datetime import datetime
-
+import unicodecsv
 from six import BytesIO
-from rest_framework.renderers import BaseRenderer
-from rest_framework.utils import encoders, json
 
-from common.utils import get_logger
-
-logger = get_logger(__file__)
+from .base import BaseFileRenderer
 
 
-class JMSCSVRender(BaseRenderer):
-
+class CSVFileRenderer(BaseFileRenderer):
     media_type = 'text/csv'
     format = 'csv'
 
-    @staticmethod
-    def _get_show_fields(fields, template):
-        if template == 'import':
-            return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
-        elif template == 'update':
-            return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
-        else:
-            return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
+    writer = None
+    buffer = None
 
-    @staticmethod
-    def _gen_table(data, fields):
-        data = data[:10000]
-        yield ['*{}'.format(f.label) if f.required else f.label for f in fields]
+    def initial_writer(self):
+        csv_buffer = BytesIO()
+        csv_buffer.write(codecs.BOM_UTF8)
+        csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
+        self.buffer = csv_buffer
+        self.writer = csv_writer
 
-        for item in data:
-            row = [item.get(f.field_name) for f in fields]
-            yield row
-
-    def set_response_disposition(self, serializer, context):
-        response = context.get('response')
-        if response and hasattr(serializer, 'Meta') and \
-                hasattr(serializer.Meta, "model"):
-            model_name = serializer.Meta.model.__name__.lower()
-            now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-            filename = "{}_{}.csv".format(model_name, now)
-            disposition = 'attachment; filename="{}"'.format(filename)
-            response['Content-Disposition'] = disposition
-
-    def render(self, data, media_type=None, renderer_context=None):
-        renderer_context = renderer_context or {}
-        request = renderer_context['request']
-        template = request.query_params.get('template', 'export')
-        view = renderer_context['view']
-
-        if isinstance(data, dict):
-            data = data.get("results", [])
-
-        if template == 'import':
-            data = [data[0]] if data else data
-
-        data = json.loads(json.dumps(data, cls=encoders.JSONEncoder))
-
-        try:
-            serializer = view.get_serializer()
-            self.set_response_disposition(serializer, renderer_context)
-        except Exception as e:
-            logger.debug(e, exc_info=True)
-            value = 'The resource not support export!'.encode('utf-8')
-        else:
-            fields = serializer.fields
-            show_fields = self._get_show_fields(fields, template)
-            table = self._gen_table(data, show_fields)
-
-            csv_buffer = BytesIO()
-            csv_buffer.write(codecs.BOM_UTF8)
-            csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
-            for row in table:
-                csv_writer.writerow(row)
-
-            value = csv_buffer.getvalue()
+    def write_row(self, row):
+        self.writer.writerow(row)
 
+    def get_rendered_value(self):
+        value = self.buffer.getvalue()
         return value
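
The template query parameter is what separates a plain export from the downloadable import/update templates: get_rendered_fields() drops read-only fields (and id) for template=import, keeps everything that is not write-only for export, and get_column_titles() prefixes required columns with *. A rough sketch of that selection logic against a made-up serializer (run it inside a configured Django/DRF shell; the serializer is not one from the codebase):

    from rest_framework import serializers


    class DemoSerializer(serializers.Serializer):
        # fabricated fields, for illustration only
        id = serializers.UUIDField(read_only=True)
        hostname = serializers.CharField(required=True, label='Hostname')
        date_created = serializers.DateTimeField(read_only=True)


    fields = DemoSerializer().fields
    # template=import -> writable columns only: ['hostname']
    print([k for k, v in fields.items() if not v.read_only and k not in ('id', 'org_id')])
    # template=export -> everything that is not write-only: ['id', 'hostname', 'date_created']
    print([k for k, v in fields.items() if not v.write_only and k != 'org_id'])
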
diff --git a/apps/common/drf/renders/excel.py b/apps/common/drf/renders/excel.py
new file mode 100644
index 000000000..0d1cb8d51
--- /dev/null
+++ b/apps/common/drf/renders/excel.py
@@ -0,0 +1,28 @@
+from openpyxl import Workbook
+from openpyxl.writer.excel import save_virtual_workbook
+
+from .base import BaseFileRenderer
+
+
+class ExcelFileRenderer(BaseFileRenderer):
+    media_type = "application/xlsx"
+    format = "xlsx"
+
+    wb = None
+    ws = None
+    row_count = 0
+
+    def initial_writer(self):
+        self.wb = Workbook()
+        self.ws = self.wb.active
+
+    def write_row(self, row):
+        self.row_count += 1
+        column_count = 0
+        for cell_value in row:
+            column_count += 1
+            self.ws.cell(row=self.row_count, column=column_count, value=cell_value)
+
+    def get_rendered_value(self):
+        value = save_virtual_workbook(self.wb)
+        return value
diff --git a/apps/jumpserver/settings/libs.py b/apps/jumpserver/settings/libs.py
index e60932464..782d2bc06 100644
--- a/apps/jumpserver/settings/libs.py
+++ b/apps/jumpserver/settings/libs.py
@@ -12,13 +12,16 @@ REST_FRAMEWORK = {
     'DEFAULT_RENDERER_CLASSES': (
         'rest_framework.renderers.JSONRenderer',
         # 'rest_framework.renderers.BrowsableAPIRenderer',
-        'common.drf.renders.JMSCSVRender',
+        'common.drf.renders.CSVFileRenderer',
+        'common.drf.renders.ExcelFileRenderer',
+
     ),
     'DEFAULT_PARSER_CLASSES': (
         'rest_framework.parsers.JSONParser',
         'rest_framework.parsers.FormParser',
         'rest_framework.parsers.MultiPartParser',
-        'common.drf.parsers.JMSCSVParser',
+        'common.drf.parsers.CSVFileParser',
+        'common.drf.parsers.ExcelFileParser',
         'rest_framework.parsers.FileUploadParser',
     ),
     'DEFAULT_AUTHENTICATION_CLASSES': (
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 391481090..fa287298e 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -105,3 +105,6 @@ azure-mgmt-compute==4.6.2
 azure-mgmt-network==2.7.0
 msrestazure==0.6.4
 adal==1.2.5
+openpyxl==3.0.5
+pyexcel==0.6.6
+pyexcel-xlsx==0.6.0
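
ExcelFileRenderer relies on openpyxl's in-memory workbook: write_row() fills cells one by one and get_rendered_value() serializes the workbook to bytes with save_virtual_workbook, which the pinned openpyxl==3.0.5 still provides (later releases deprecate it in favor of saving to a BytesIO). A small round-trip check of exactly those calls, with made-up cell values:

    import io
    from openpyxl import Workbook, load_workbook
    from openpyxl.writer.excel import save_virtual_workbook

    wb = Workbook()
    ws = wb.active
    ws.cell(row=1, column=1, value='*Hostname')
    ws.cell(row=1, column=2, value='IP')
    ws.cell(row=2, column=1, value='web-01')
    ws.cell(row=2, column=2, value='10.0.0.1')

    blob = save_virtual_workbook(wb)            # bytes, like get_rendered_value()
    wb2 = load_workbook(io.BytesIO(blob))
    print([[c.value for c in row] for row in wb2.active.rows])
    # [['*Hostname', 'IP'], ['web-01', '10.0.0.1']]
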