mirror of https://github.com/jumpserver/jumpserver
Browse Source
* refactor(drf_renderer): 添加 ExcelRenderer 支持导出excel文件格式; 优化CSVRenderer, 抽象 BaseRenderer * perf(renderer): 支持导出资源详情 * refactor(drf_parser): 添加 ExcelParser 支持导入excel文件格式; 优化CSVParser, 抽象 BaseParser * refactor(drf_parser): 添加 ExcelParser 支持导入excel文件格式; 优化CSVParser, 抽象 BaseParser 2 * perf(renderer): 捕获renderer处理异常 * perf: 添加excel依赖包 * perf(drf): 优化导入导出错误日志 * perf: 添加依赖包 pyexcel-io==0.6.4 * perf: 添加依赖包pyexcel-xlsx==0.6.0 * feat: 修改drf/renderer&parser变量命名 * feat: 修改drf/renderer的bug * feat: 修改drf/renderer&parser变量命名 Co-authored-by: Bai <bugatti_it@163.com>pull/5177/head
fit2bot
4 years ago
committed by
GitHub
10 changed files with 339 additions and 186 deletions
@ -1 +1,2 @@
|
||||
from .csv import * |
||||
from .csv import * |
||||
from .excel import * |
@ -0,0 +1,132 @@
|
||||
import abc |
||||
import json |
||||
import codecs |
||||
from django.utils.translation import ugettext_lazy as _ |
||||
from rest_framework.parsers import BaseParser |
||||
from rest_framework import status |
||||
from rest_framework.exceptions import ParseError, APIException |
||||
from common.utils import get_logger |
||||
|
||||
logger = get_logger(__file__) |
||||
|
||||
|
||||
class FileContentOverflowedError(APIException): |
||||
status_code = status.HTTP_400_BAD_REQUEST |
||||
default_code = 'file_content_overflowed' |
||||
default_detail = _('The file content overflowed (The maximum length `{}` bytes)') |
||||
|
||||
|
||||
class BaseFileParser(BaseParser): |
||||
|
||||
FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10 |
||||
|
||||
serializer_cls = None |
||||
|
||||
def check_content_length(self, meta): |
||||
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))) |
||||
if content_length > self.FILE_CONTENT_MAX_LENGTH: |
||||
msg = FileContentOverflowedError.default_detail.format(self.FILE_CONTENT_MAX_LENGTH) |
||||
logger.error(msg) |
||||
raise FileContentOverflowedError(msg) |
||||
|
||||
@staticmethod |
||||
def get_stream_data(stream): |
||||
stream_data = stream.read() |
||||
stream_data = stream_data.strip(codecs.BOM_UTF8) |
||||
return stream_data |
||||
|
||||
@abc.abstractmethod |
||||
def generate_rows(self, stream_data): |
||||
raise NotImplemented |
||||
|
||||
def get_column_titles(self, rows): |
||||
return next(rows) |
||||
|
||||
def convert_to_field_names(self, column_titles): |
||||
fields_map = {} |
||||
fields = self.serializer_cls().fields |
||||
fields_map.update({v.label: k for k, v in fields.items()}) |
||||
fields_map.update({k: k for k, _ in fields.items()}) |
||||
field_names = [ |
||||
fields_map.get(column_title.strip('*'), '') |
||||
for column_title in column_titles |
||||
] |
||||
return field_names |
||||
|
||||
@staticmethod |
||||
def _replace_chinese_quote(s): |
||||
trans_table = str.maketrans({ |
||||
'“': '"', |
||||
'”': '"', |
||||
'‘': '"', |
||||
'’': '"', |
||||
'\'': '"' |
||||
}) |
||||
return s.translate(trans_table) |
||||
|
||||
@classmethod |
||||
def process_row(cls, row): |
||||
""" |
||||
构建json数据前的行处理 |
||||
""" |
||||
new_row = [] |
||||
for col in row: |
||||
# 转换中文引号 |
||||
col = cls._replace_chinese_quote(col) |
||||
# 列表/字典转换 |
||||
if isinstance(col, str) and ( |
||||
(col.startswith('[') and col.endswith(']')) |
||||
or |
||||
(col.startswith("{") and col.endswith("}")) |
||||
): |
||||
col = json.loads(col) |
||||
new_row.append(col) |
||||
return new_row |
||||
|
||||
@staticmethod |
||||
def process_row_data(row_data): |
||||
""" |
||||
构建json数据后的行数据处理 |
||||
""" |
||||
new_row_data = {} |
||||
for k, v in row_data.items(): |
||||
if isinstance(v, list) or isinstance(v, dict) or isinstance(v, str) and k.strip() and v.strip(): |
||||
new_row_data[k] = v |
||||
return new_row_data |
||||
|
||||
def generate_data(self, fields_name, rows): |
||||
data = [] |
||||
for row in rows: |
||||
# 空行不处理 |
||||
if not any(row): |
||||
continue |
||||
row = self.process_row(row) |
||||
row_data = dict(zip(fields_name, row)) |
||||
row_data = self.process_row_data(row_data) |
||||
data.append(row_data) |
||||
return data |
||||
|
||||
def parse(self, stream, media_type=None, parser_context=None): |
||||
parser_context = parser_context or {} |
||||
|
||||
try: |
||||
view = parser_context['view'] |
||||
meta = view.request.META |
||||
self.serializer_cls = view.get_serializer_class() |
||||
except Exception as e: |
||||
logger.debug(e, exc_info=True) |
||||
raise ParseError('The resource does not support imports!') |
||||
|
||||
self.check_content_length(meta) |
||||
|
||||
try: |
||||
stream_data = self.get_stream_data(stream) |
||||
rows = self.generate_rows(stream_data) |
||||
column_titles = self.get_column_titles(rows) |
||||
field_names = self.convert_to_field_names(column_titles) |
||||
data = self.generate_data(field_names, rows) |
||||
return data |
||||
except Exception as e: |
||||
logger.error(e, exc_info=True) |
||||
raise ParseError('Parse error! ({})'.format(self.media_type)) |
||||
|
@ -0,0 +1,14 @@
|
||||
import pyexcel |
||||
from .base import BaseFileParser |
||||
|
||||
|
||||
class ExcelFileParser(BaseFileParser): |
||||
|
||||
media_type = 'text/xlsx' |
||||
|
||||
def generate_rows(self, stream_data): |
||||
workbook = pyexcel.get_book(file_type='xlsx', file_content=stream_data) |
||||
# 默认获取第一个工作表sheet |
||||
sheet = workbook.sheet_by_index(0) |
||||
rows = sheet.rows() |
||||
return rows |
@ -0,0 +1,132 @@
|
||||
import abc |
||||
from datetime import datetime |
||||
from rest_framework.renderers import BaseRenderer |
||||
from rest_framework.utils import encoders, json |
||||
|
||||
from common.utils import get_logger |
||||
|
||||
logger = get_logger(__file__) |
||||
|
||||
|
||||
class BaseFileRenderer(BaseRenderer): |
||||
# 渲染模版标识, 导入、导出、更新模版: ['import', 'update', 'export'] |
||||
template = 'export' |
||||
serializer = None |
||||
|
||||
@staticmethod |
||||
def _check_validation_data(data): |
||||
detail_key = "detail" |
||||
if detail_key in data: |
||||
return False |
||||
return True |
||||
|
||||
@staticmethod |
||||
def _json_format_response(response_data): |
||||
return json.dumps(response_data) |
||||
|
||||
def set_response_disposition(self, response): |
||||
serializer = self.serializer |
||||
if response and hasattr(serializer, 'Meta') and hasattr(serializer.Meta, "model"): |
||||
model_name = serializer.Meta.model.__name__.lower() |
||||
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
||||
filename = "{}_{}.{}".format(model_name, now, self.format) |
||||
disposition = 'attachment; filename="{}"'.format(filename) |
||||
response['Content-Disposition'] = disposition |
||||
|
||||
def get_rendered_fields(self): |
||||
fields = self.serializer.fields |
||||
if self.template == 'import': |
||||
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id'] |
||||
elif self.template == 'update': |
||||
return [v for k, v in fields.items() if not v.read_only and k != "org_id"] |
||||
else: |
||||
return [v for k, v in fields.items() if not v.write_only and k != "org_id"] |
||||
|
||||
@staticmethod |
||||
def get_column_titles(render_fields): |
||||
return [ |
||||
'*{}'.format(field.label) if field.required else str(field.label) |
||||
for field in render_fields |
||||
] |
||||
|
||||
def process_data(self, data): |
||||
results = data['results'] if 'results' in data else data |
||||
|
||||
if isinstance(results, dict): |
||||
results = [results] |
||||
|
||||
if self.template == 'import': |
||||
results = [results[0]] if results else results |
||||
|
||||
else: |
||||
# 限制数据数量 |
||||
results = results[:10000] |
||||
# 会将一些 UUID 字段转化为 string |
||||
results = json.loads(json.dumps(results, cls=encoders.JSONEncoder)) |
||||
return results |
||||
|
||||
@staticmethod |
||||
def generate_rows(data, render_fields): |
||||
for item in data: |
||||
row = [] |
||||
for field in render_fields: |
||||
value = item.get(field.field_name) |
||||
value = str(value) if value else '' |
||||
row.append(value) |
||||
yield row |
||||
|
||||
@abc.abstractmethod |
||||
def initial_writer(self): |
||||
raise NotImplementedError |
||||
|
||||
def write_column_titles(self, column_titles): |
||||
self.write_row(column_titles) |
||||
|
||||
def write_rows(self, rows): |
||||
for row in rows: |
||||
self.write_row(row) |
||||
|
||||
@abc.abstractmethod |
||||
def write_row(self, row): |
||||
raise NotImplementedError |
||||
|
||||
@abc.abstractmethod |
||||
def get_rendered_value(self): |
||||
raise NotImplementedError |
||||
|
||||
def render(self, data, accepted_media_type=None, renderer_context=None): |
||||
if data is None: |
||||
return bytes() |
||||
|
||||
if not self._check_validation_data(data): |
||||
return self._json_format_response(data) |
||||
|
||||
try: |
||||
renderer_context = renderer_context or {} |
||||
request = renderer_context['request'] |
||||
response = renderer_context['response'] |
||||
view = renderer_context['view'] |
||||
self.template = request.query_params.get('template', 'export') |
||||
self.serializer = view.get_serializer() |
||||
self.set_response_disposition(response) |
||||
except Exception as e: |
||||
logger.debug(e, exc_info=True) |
||||
value = 'The resource not support export!'.encode('utf-8') |
||||
return value |
||||
|
||||
try: |
||||
rendered_fields = self.get_rendered_fields() |
||||
column_titles = self.get_column_titles(rendered_fields) |
||||
data = self.process_data(data) |
||||
rows = self.generate_rows(data, rendered_fields) |
||||
self.initial_writer() |
||||
self.write_column_titles(column_titles) |
||||
self.write_rows(rows) |
||||
value = self.get_rendered_value() |
||||
except Exception as e: |
||||
logger.debug(e, exc_info=True) |
||||
value = 'Render error! ({})'.format(self.media_type).encode('utf-8') |
||||
return value |
||||
|
||||
return value |
||||
|
@ -1,83 +1,30 @@
|
||||
# ~*~ coding: utf-8 ~*~ |
||||
# |
||||
|
||||
import unicodecsv |
||||
import codecs |
||||
from datetime import datetime |
||||
|
||||
import unicodecsv |
||||
from six import BytesIO |
||||
from rest_framework.renderers import BaseRenderer |
||||
from rest_framework.utils import encoders, json |
||||
|
||||
from common.utils import get_logger |
||||
|
||||
logger = get_logger(__file__) |
||||
from .base import BaseFileRenderer |
||||
|
||||
|
||||
class JMSCSVRender(BaseRenderer): |
||||
|
||||
class CSVFileRenderer(BaseFileRenderer): |
||||
media_type = 'text/csv' |
||||
format = 'csv' |
||||
|
||||
@staticmethod |
||||
def _get_show_fields(fields, template): |
||||
if template == 'import': |
||||
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id'] |
||||
elif template == 'update': |
||||
return [v for k, v in fields.items() if not v.read_only and k != "org_id"] |
||||
else: |
||||
return [v for k, v in fields.items() if not v.write_only and k != "org_id"] |
||||
|
||||
@staticmethod |
||||
def _gen_table(data, fields): |
||||
data = data[:10000] |
||||
yield ['*{}'.format(f.label) if f.required else f.label for f in fields] |
||||
|
||||
for item in data: |
||||
row = [item.get(f.field_name) for f in fields] |
||||
yield row |
||||
|
||||
def set_response_disposition(self, serializer, context): |
||||
response = context.get('response') |
||||
if response and hasattr(serializer, 'Meta') and \ |
||||
hasattr(serializer.Meta, "model"): |
||||
model_name = serializer.Meta.model.__name__.lower() |
||||
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
||||
filename = "{}_{}.csv".format(model_name, now) |
||||
disposition = 'attachment; filename="{}"'.format(filename) |
||||
response['Content-Disposition'] = disposition |
||||
|
||||
def render(self, data, media_type=None, renderer_context=None): |
||||
renderer_context = renderer_context or {} |
||||
request = renderer_context['request'] |
||||
template = request.query_params.get('template', 'export') |
||||
view = renderer_context['view'] |
||||
|
||||
if isinstance(data, dict): |
||||
data = data.get("results", []) |
||||
|
||||
if template == 'import': |
||||
data = [data[0]] if data else data |
||||
|
||||
data = json.loads(json.dumps(data, cls=encoders.JSONEncoder)) |
||||
|
||||
try: |
||||
serializer = view.get_serializer() |
||||
self.set_response_disposition(serializer, renderer_context) |
||||
except Exception as e: |
||||
logger.debug(e, exc_info=True) |
||||
value = 'The resource not support export!'.encode('utf-8') |
||||
else: |
||||
fields = serializer.fields |
||||
show_fields = self._get_show_fields(fields, template) |
||||
table = self._gen_table(data, show_fields) |
||||
writer = None |
||||
buffer = None |
||||
|
||||
csv_buffer = BytesIO() |
||||
csv_buffer.write(codecs.BOM_UTF8) |
||||
csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8') |
||||
for row in table: |
||||
csv_writer.writerow(row) |
||||
def initial_writer(self): |
||||
csv_buffer = BytesIO() |
||||
csv_buffer.write(codecs.BOM_UTF8) |
||||
csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8') |
||||
self.buffer = csv_buffer |
||||
self.writer = csv_writer |
||||
|
||||
value = csv_buffer.getvalue() |
||||
def write_row(self, row): |
||||
self.writer.writerow(row) |
||||
|
||||
def get_rendered_value(self): |
||||
value = self.buffer.getvalue() |
||||
return value |
||||
|
@ -0,0 +1,28 @@
|
||||
from openpyxl import Workbook |
||||
from openpyxl.writer.excel import save_virtual_workbook |
||||
|
||||
from .base import BaseFileRenderer |
||||
|
||||
|
||||
class ExcelFileRenderer(BaseFileRenderer): |
||||
media_type = "application/xlsx" |
||||
format = "xlsx" |
||||
|
||||
wb = None |
||||
ws = None |
||||
row_count = 0 |
||||
|
||||
def initial_writer(self): |
||||
self.wb = Workbook() |
||||
self.ws = self.wb.active |
||||
|
||||
def write_row(self, row): |
||||
self.row_count += 1 |
||||
column_count = 0 |
||||
for cell_value in row: |
||||
column_count += 1 |
||||
self.ws.cell(row=self.row_count, column=column_count, value=cell_value) |
||||
|
||||
def get_rendered_value(self): |
||||
value = save_virtual_workbook(self.wb) |
||||
return value |
Loading…
Reference in new issue