mirror of https://github.com/jumpserver/jumpserver
feat(excel): Add Excel import/export (#5124)

* refactor(drf_renderer): add an ExcelRenderer that supports exporting to the xlsx file format; refine CSVRenderer and extract a common BaseRenderer
* perf(renderer): support exporting resource details
* refactor(drf_parser): add an ExcelParser that supports importing the xlsx file format; refine CSVParser and extract a common BaseParser
* refactor(drf_parser): add an ExcelParser that supports importing the xlsx file format; refine CSVParser and extract a common BaseParser 2
* perf(renderer): catch exceptions raised while rendering
* perf: add the excel dependency packages
* perf(drf): improve the import/export error logs
* perf: add dependency pyexcel-io==0.6.4
* perf: add dependency pyexcel-xlsx==0.6.0
* feat: rename variables in drf/renderer & parser
* feat: fix a bug in drf/renderer
* feat: rename variables in drf/renderer & parser

Co-authored-by: Bai <bugatti_it@163.com>
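As a quick orientation before the diff, here is a minimal client-side sketch of the round trip this commit enables. The host, endpoint, and token are hypothetical placeholders, not code from this commit; any JumpServer list API served by these renderer/parser classes behaves the same way.

import requests

BASE = 'https://jumpserver.example.com'        # hypothetical host
headers = {'Authorization': 'Bearer <token>'}  # hypothetical auth

# Export: DRF content negotiation picks ExcelFileRenderer via ?format=xlsx;
# ?template=import would download a one-row template instead of the full listing.
resp = requests.get(
    BASE + '/api/v1/users/users/',
    params={'format': 'xlsx', 'template': 'export'},
    headers=headers,
)
with open('users.xlsx', 'wb') as f:
    f.write(resp.content)

# Import: posting the file back with Content-Type text/xlsx routes the body
# through ExcelFileParser, which turns rows into serializer data.
with open('users.xlsx', 'rb') as f:
    requests.post(
        BASE + '/api/v1/users/users/',
        data=f.read(),
        headers={**headers, 'Content-Type': 'text/xlsx'},
    )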
parent
619b521ea1
commit
43b5e97b95
@@ -1 +1,2 @@
 from .csv import *
+from .excel import *
@@ -0,0 +1,132 @@
+import abc
+import json
+import codecs
+from django.utils.translation import ugettext_lazy as _
+from rest_framework.parsers import BaseParser
+from rest_framework import status
+from rest_framework.exceptions import ParseError, APIException
+from common.utils import get_logger
+
+logger = get_logger(__file__)
+
+
+class FileContentOverflowedError(APIException):
+    status_code = status.HTTP_400_BAD_REQUEST
+    default_code = 'file_content_overflowed'
+    default_detail = _('The file content overflowed (The maximum length `{}` bytes)')
+
+
+class BaseFileParser(BaseParser):
+
+    FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
+
+    serializer_cls = None
+
+    def check_content_length(self, meta):
+        content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
+        if content_length > self.FILE_CONTENT_MAX_LENGTH:
+            msg = FileContentOverflowedError.default_detail.format(self.FILE_CONTENT_MAX_LENGTH)
+            logger.error(msg)
+            raise FileContentOverflowedError(msg)
+
+    @staticmethod
+    def get_stream_data(stream):
+        stream_data = stream.read()
+        stream_data = stream_data.strip(codecs.BOM_UTF8)
+        return stream_data
+
+    @abc.abstractmethod
+    def generate_rows(self, stream_data):
+        raise NotImplementedError
+
+    def get_column_titles(self, rows):
+        return next(rows)
+
+    def convert_to_field_names(self, column_titles):
+        fields_map = {}
+        fields = self.serializer_cls().fields
+        fields_map.update({v.label: k for k, v in fields.items()})
+        fields_map.update({k: k for k, _ in fields.items()})
+        field_names = [
+            fields_map.get(column_title.strip('*'), '')
+            for column_title in column_titles
+        ]
+        return field_names
+
+    @staticmethod
+    def _replace_chinese_quote(s):
+        trans_table = str.maketrans({
+            '“': '"',
+            '”': '"',
+            '‘': '"',
+            '’': '"',
+            '\'': '"'
+        })
+        return s.translate(trans_table)
+
+    @classmethod
+    def process_row(cls, row):
+        """
+        Row processing before the JSON data is built
+        """
+        new_row = []
+        for col in row:
+            # Normalize Chinese quotation marks
+            col = cls._replace_chinese_quote(col)
+            # Convert list/dict literals
+            if isinstance(col, str) and (
+                (col.startswith('[') and col.endswith(']'))
+                or
+                (col.startswith("{") and col.endswith("}"))
+            ):
+                col = json.loads(col)
+            new_row.append(col)
+        return new_row
+
+    @staticmethod
+    def process_row_data(row_data):
+        """
+        Row data processing after the JSON data is built
+        """
+        new_row_data = {}
+        for k, v in row_data.items():
+            if isinstance(v, list) or isinstance(v, dict) or isinstance(v, str) and k.strip() and v.strip():
+                new_row_data[k] = v
+        return new_row_data
+
+    def generate_data(self, fields_name, rows):
+        data = []
+        for row in rows:
+            # Skip empty rows
+            if not any(row):
+                continue
+            row = self.process_row(row)
+            row_data = dict(zip(fields_name, row))
+            row_data = self.process_row_data(row_data)
+            data.append(row_data)
+        return data
+
+    def parse(self, stream, media_type=None, parser_context=None):
+        parser_context = parser_context or {}
+
+        try:
+            view = parser_context['view']
+            meta = view.request.META
+            self.serializer_cls = view.get_serializer_class()
+        except Exception as e:
+            logger.debug(e, exc_info=True)
+            raise ParseError('The resource does not support imports!')
+
+        self.check_content_length(meta)
+
+        try:
+            stream_data = self.get_stream_data(stream)
+            rows = self.generate_rows(stream_data)
+            column_titles = self.get_column_titles(rows)
+            field_names = self.convert_to_field_names(column_titles)
+            data = self.generate_data(field_names, rows)
+            return data
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise ParseError('Parse error! ({})'.format(self.media_type))
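How the header-mapping step of this new base class works in practice: a standalone sketch of the lookup that convert_to_field_names performs, using a made-up two-field serializer (the serializer, labels, and values are assumptions for illustration, not code from this commit).

from django.conf import settings
settings.configure()  # minimal Django config so DRF serializers work standalone

from rest_framework import serializers


class UserSerializer(serializers.Serializer):  # hypothetical serializer
    name = serializers.CharField(label='Name')
    email = serializers.CharField(label='Email', required=False)


fields = UserSerializer().fields
fields_map = {}
fields_map.update({v.label: k for k, v in fields.items()})  # label -> field name
fields_map.update({k: k for k, _ in fields.items()})        # field name -> itself

# Exported titles prefix required columns with '*', so it is stripped first
column_titles = ['*Name', 'email']
print([fields_map.get(t.strip('*'), '') for t in column_titles])  # ['name', 'email']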
@@ -1,32 +1,13 @@
 # ~*~ coding: utf-8 ~*~
 #
 
-import json
 import chardet
-import codecs
 import unicodecsv
 
-from django.utils.translation import ugettext as _
-from rest_framework.parsers import BaseParser
-from rest_framework.exceptions import ParseError, APIException
-from rest_framework import status
-
-from common.utils import get_logger
-
-logger = get_logger(__file__)
+from .base import BaseFileParser
 
 
-class CsvDataTooBig(APIException):
-    status_code = status.HTTP_400_BAD_REQUEST
-    default_code = 'csv_data_too_big'
-    default_detail = _('The max size of CSV is %d bytes')
-
-
-class JMSCSVParser(BaseParser):
-    """
-    Parses CSV file to serializer data
-    """
-    CSV_UPLOAD_MAX_SIZE = 1024 * 1024 * 10
+class CSVFileParser(BaseFileParser):
 
     media_type = 'text/csv'
@@ -38,99 +19,10 @@ class JMSCSVParser(BaseParser):
         for line in stream.splitlines():
             yield line
 
-    @staticmethod
-    def _gen_rows(csv_data, charset='utf-8', **kwargs):
-        csv_reader = unicodecsv.reader(csv_data, encoding=charset, **kwargs)
+    def generate_rows(self, stream_data):
+        detect_result = chardet.detect(stream_data)
+        encoding = detect_result.get("encoding", "utf-8")
+        lines = self._universal_newlines(stream_data)
+        csv_reader = unicodecsv.reader(lines, encoding=encoding)
         for row in csv_reader:
             if not any(row):  # skip empty rows
                 continue
             yield row
-
-    @staticmethod
-    def _get_fields_map(serializer_cls):
-        fields_map = {}
-        fields = serializer_cls().fields
-        fields_map.update({v.label: k for k, v in fields.items()})
-        fields_map.update({k: k for k, _ in fields.items()})
-        return fields_map
-
-    @staticmethod
-    def _replace_chinese_quot(str_):
-        trans_table = str.maketrans({
-            '“': '"',
-            '”': '"',
-            '‘': '"',
-            '’': '"',
-            '\'': '"'
-        })
-        return str_.translate(trans_table)
-
-    @classmethod
-    def _process_row(cls, row):
-        """
-        Row processing before the JSON data is built
-        """
-        _row = []
-
-        for col in row:
-            # Convert list literals
-            if isinstance(col, str) and col.startswith('[') and col.endswith(']'):
-                col = cls._replace_chinese_quot(col)
-                col = json.loads(col)
-            # Convert dict literals
-            if isinstance(col, str) and col.startswith("{") and col.endswith("}"):
-                col = cls._replace_chinese_quot(col)
-                col = json.loads(col)
-            _row.append(col)
-        return _row
-
-    @staticmethod
-    def _process_row_data(row_data):
-        """
-        Row data processing after the JSON data is built
-        """
-        _row_data = {}
-        for k, v in row_data.items():
-            if isinstance(v, list) or isinstance(v, dict)\
-                    or isinstance(v, str) and k.strip() and v.strip():
-                _row_data[k] = v
-        return _row_data
-
-    def parse(self, stream, media_type=None, parser_context=None):
-        parser_context = parser_context or {}
-        try:
-            view = parser_context['view']
-            meta = view.request.META
-            serializer_cls = view.get_serializer_class()
-        except Exception as e:
-            logger.debug(e, exc_info=True)
-            raise ParseError('The resource does not support imports!')
-
-        content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
-        if content_length > self.CSV_UPLOAD_MAX_SIZE:
-            msg = CsvDataTooBig.default_detail % self.CSV_UPLOAD_MAX_SIZE
-            logger.error(msg)
-            raise CsvDataTooBig(msg)
-
-        try:
-            stream_data = stream.read()
-            stream_data = stream_data.strip(codecs.BOM_UTF8)
-            detect_result = chardet.detect(stream_data)
-            encoding = detect_result.get("encoding", "utf-8")
-            binary = self._universal_newlines(stream_data)
-            rows = self._gen_rows(binary, charset=encoding)
-
-            header = next(rows)
-            fields_map = self._get_fields_map(serializer_cls)
-            header = [fields_map.get(name.strip('*'), '') for name in header]
-
-            data = []
-            for row in rows:
-                row = self._process_row(row)
-                row_data = dict(zip(header, row))
-                row_data = self._process_row_data(row_data)
-                data.append(row_data)
-            return data
-        except Exception as e:
-            logger.error(e, exc_info=True)
-            raise ParseError('CSV parse error!')
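The reworked generate_rows above is the interesting part of the CSV path: the upload's encoding is now sniffed with chardet instead of assumed. A self-contained sketch of the same flow with fabricated GBK input (the sample data is made up; note that chardet may report a compatible superset codec for short samples):

import chardet
import unicodecsv

# A CSV payload in a non-UTF-8 encoding, as a Windows Excel export might produce
stream_data = 'name,comment\r\nalice,你好\r\n'.encode('gbk')

detect_result = chardet.detect(stream_data)
encoding = detect_result.get('encoding', 'utf-8')

lines = (line for line in stream_data.splitlines())  # what _universal_newlines yields
for row in unicodecsv.reader(lines, encoding=encoding):
    if not any(row):  # skip empty rows, as the parser does
        continue
    print(row)
# ['name', 'comment']
# ['alice', '你好']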
@@ -0,0 +1,14 @@
+import pyexcel
+from .base import BaseFileParser
+
+
+class ExcelFileParser(BaseFileParser):
+
+    media_type = 'text/xlsx'
+
+    def generate_rows(self, stream_data):
+        workbook = pyexcel.get_book(file_type='xlsx', file_content=stream_data)
+        # Use the first worksheet by default
+        sheet = workbook.sheet_by_index(0)
+        rows = sheet.rows()
+        return rows
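A round-trip sketch for the Excel path: build an xlsx file in memory (openpyxl is already in the dependency list) and feed its bytes through the same pyexcel calls the parser uses. The sample rows are made up.

import pyexcel
from openpyxl import Workbook
from openpyxl.writer.excel import save_virtual_workbook

wb = Workbook()
ws = wb.active
ws.append(['*Name', 'Email'])                 # header row, as the renderer writes it
ws.append(['alice', 'alice@example.com'])
stream_data = save_virtual_workbook(wb)       # the raw bytes the parser receives

book = pyexcel.get_book(file_type='xlsx', file_content=stream_data)
sheet = book.sheet_by_index(0)                # first worksheet, as in generate_rows
for row in sheet.rows():
    print(row)
# ['*Name', 'Email']
# ['alice', 'alice@example.com']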
@@ -1,6 +1,7 @@
 from rest_framework import renderers
 
 from .csv import *
+from .excel import *
 
 
 class PassthroughRenderer(renderers.BaseRenderer):
@@ -0,0 +1,132 @@
+import abc
+from datetime import datetime
+from rest_framework.renderers import BaseRenderer
+from rest_framework.utils import encoders, json
+
+from common.utils import get_logger
+
+logger = get_logger(__file__)
+
+
+class BaseFileRenderer(BaseRenderer):
+    # Rendering template type, one of: ['import', 'update', 'export']
+    template = 'export'
+    serializer = None
+
+    @staticmethod
+    def _check_validation_data(data):
+        detail_key = "detail"
+        if detail_key in data:
+            return False
+        return True
+
+    @staticmethod
+    def _json_format_response(response_data):
+        return json.dumps(response_data)
+
+    def set_response_disposition(self, response):
+        serializer = self.serializer
+        if response and hasattr(serializer, 'Meta') and hasattr(serializer.Meta, "model"):
+            model_name = serializer.Meta.model.__name__.lower()
+            now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+            filename = "{}_{}.{}".format(model_name, now, self.format)
+            disposition = 'attachment; filename="{}"'.format(filename)
+            response['Content-Disposition'] = disposition
+
+    def get_rendered_fields(self):
+        fields = self.serializer.fields
+        if self.template == 'import':
+            return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
+        elif self.template == 'update':
+            return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
+        else:
+            return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
+
+    @staticmethod
+    def get_column_titles(render_fields):
+        return [
+            '*{}'.format(field.label) if field.required else str(field.label)
+            for field in render_fields
+        ]
+
+    def process_data(self, data):
+        results = data['results'] if 'results' in data else data
+
+        if isinstance(results, dict):
+            results = [results]
+
+        if self.template == 'import':
+            results = [results[0]] if results else results
+
+        else:
+            # Limit the amount of data
+            results = results[:10000]
+        # Converts some fields (e.g. UUIDs) to strings
+        results = json.loads(json.dumps(results, cls=encoders.JSONEncoder))
+        return results
+
+    @staticmethod
+    def generate_rows(data, render_fields):
+        for item in data:
+            row = []
+            for field in render_fields:
+                value = item.get(field.field_name)
+                value = str(value) if value else ''
+                row.append(value)
+            yield row
+
+    @abc.abstractmethod
+    def initial_writer(self):
+        raise NotImplementedError
+
+    def write_column_titles(self, column_titles):
+        self.write_row(column_titles)
+
+    def write_rows(self, rows):
+        for row in rows:
+            self.write_row(row)
+
+    @abc.abstractmethod
+    def write_row(self, row):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def get_rendered_value(self):
+        raise NotImplementedError
+
+    def render(self, data, accepted_media_type=None, renderer_context=None):
+        if data is None:
+            return bytes()
+
+        if not self._check_validation_data(data):
+            return self._json_format_response(data)
+
+        try:
+            renderer_context = renderer_context or {}
+            request = renderer_context['request']
+            response = renderer_context['response']
+            view = renderer_context['view']
+            self.template = request.query_params.get('template', 'export')
+            self.serializer = view.get_serializer()
+            self.set_response_disposition(response)
+        except Exception as e:
+            logger.debug(e, exc_info=True)
+            value = 'The resource not support export!'.encode('utf-8')
+            return value
+
+        try:
+            rendered_fields = self.get_rendered_fields()
+            column_titles = self.get_column_titles(rendered_fields)
+            data = self.process_data(data)
+            rows = self.generate_rows(data, rendered_fields)
+            self.initial_writer()
+            self.write_column_titles(column_titles)
+            self.write_rows(rows)
+            value = self.get_rendered_value()
+        except Exception as e:
+            logger.debug(e, exc_info=True)
+            value = 'Render error! ({})'.format(self.media_type).encode('utf-8')
+            return value
+
+        return value
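What the template switch means for columns, sketched with a made-up serializer (everything in this snippet is an assumption for illustration): the import template drops read-only fields plus id, the update template keeps a writable id so existing rows can be addressed, and a plain export shows every readable field.

from django.conf import settings
settings.configure()  # minimal Django config so DRF serializers work standalone

from rest_framework import serializers


class AssetSerializer(serializers.Serializer):  # hypothetical serializer
    id = serializers.UUIDField(required=False)  # writable here, so update templates keep it
    hostname = serializers.CharField(label='Hostname')
    password = serializers.CharField(write_only=True)


fields = AssetSerializer().fields

export = [k for k, v in fields.items() if not v.write_only and k != 'org_id']
update = [k for k, v in fields.items() if not v.read_only and k != 'org_id']
import_ = [k for k, v in fields.items() if not v.read_only and k != 'org_id' and k != 'id']

print(export)   # ['id', 'hostname']
print(update)   # ['id', 'hostname', 'password']
print(import_)  # ['hostname', 'password']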
@@ -1,83 +1,30 @@
 # ~*~ coding: utf-8 ~*~
 #
 
-import unicodecsv
 import codecs
-from datetime import datetime
 
+import unicodecsv
 from six import BytesIO
-from rest_framework.renderers import BaseRenderer
-from rest_framework.utils import encoders, json
 
-from common.utils import get_logger
-
-logger = get_logger(__file__)
+from .base import BaseFileRenderer
 
 
-class JMSCSVRender(BaseRenderer):
+class CSVFileRenderer(BaseFileRenderer):
 
     media_type = 'text/csv'
     format = 'csv'
 
-    @staticmethod
-    def _get_show_fields(fields, template):
-        if template == 'import':
-            return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
-        elif template == 'update':
-            return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
-        else:
-            return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
+    writer = None
+    buffer = None
 
-    @staticmethod
-    def _gen_table(data, fields):
-        data = data[:10000]
-        yield ['*{}'.format(f.label) if f.required else f.label for f in fields]
+    def initial_writer(self):
+        csv_buffer = BytesIO()
+        csv_buffer.write(codecs.BOM_UTF8)
+        csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
+        self.buffer = csv_buffer
+        self.writer = csv_writer
 
-        for item in data:
-            row = [item.get(f.field_name) for f in fields]
-            yield row
-
-    def set_response_disposition(self, serializer, context):
-        response = context.get('response')
-        if response and hasattr(serializer, 'Meta') and \
-                hasattr(serializer.Meta, "model"):
-            model_name = serializer.Meta.model.__name__.lower()
-            now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-            filename = "{}_{}.csv".format(model_name, now)
-            disposition = 'attachment; filename="{}"'.format(filename)
-            response['Content-Disposition'] = disposition
-
-    def render(self, data, media_type=None, renderer_context=None):
-        renderer_context = renderer_context or {}
-        request = renderer_context['request']
-        template = request.query_params.get('template', 'export')
-        view = renderer_context['view']
-
-        if isinstance(data, dict):
-            data = data.get("results", [])
-
-        if template == 'import':
-            data = [data[0]] if data else data
-
-        data = json.loads(json.dumps(data, cls=encoders.JSONEncoder))
-
-        try:
-            serializer = view.get_serializer()
-            self.set_response_disposition(serializer, renderer_context)
-        except Exception as e:
-            logger.debug(e, exc_info=True)
-            value = 'The resource not support export!'.encode('utf-8')
-        else:
-            fields = serializer.fields
-            show_fields = self._get_show_fields(fields, template)
-            table = self._gen_table(data, show_fields)
-
-            csv_buffer = BytesIO()
-            csv_buffer.write(codecs.BOM_UTF8)
-            csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
-            for row in table:
-                csv_writer.writerow(row)
-
-            value = csv_buffer.getvalue()
+    def write_row(self, row):
+        self.writer.writerow(row)
 
+    def get_rendered_value(self):
+        value = self.buffer.getvalue()
         return value
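With the shared base class, the CSV renderer reduces to the three writer primitives above. A standalone sketch of what they produce (the sample rows are made up): writing the UTF-8 BOM first is what lets Excel on Windows open non-ASCII CSV correctly.

import codecs

import unicodecsv
from six import BytesIO

buffer = BytesIO()
buffer.write(codecs.BOM_UTF8)                # BOM first, as initial_writer does
writer = unicodecsv.writer(buffer, encoding='utf-8')
writer.writerow(['*Hostname', 'Comment'])    # column titles
writer.writerow(['web-01', '测试'])          # non-ASCII survives the round trip
value = buffer.getvalue()                    # bytes for the HTTP response body
print(value.startswith(codecs.BOM_UTF8))     # True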
@@ -0,0 +1,28 @@
+from openpyxl import Workbook
+from openpyxl.writer.excel import save_virtual_workbook
+
+from .base import BaseFileRenderer
+
+
+class ExcelFileRenderer(BaseFileRenderer):
+    media_type = "application/xlsx"
+    format = "xlsx"
+
+    wb = None
+    ws = None
+    row_count = 0
+
+    def initial_writer(self):
+        self.wb = Workbook()
+        self.ws = self.wb.active
+
+    def write_row(self, row):
+        self.row_count += 1
+        column_count = 0
+        for cell_value in row:
+            column_count += 1
+            self.ws.cell(row=self.row_count, column=column_count, value=cell_value)
+
+    def get_rendered_value(self):
+        value = save_virtual_workbook(self.wb)
+        return value
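The same rows rendered through the Excel writer, as a standalone sketch with made-up data: cells are addressed with 1-based row/column indices, and save_virtual_workbook (present in the pinned openpyxl 3.0.5; later openpyxl releases deprecate it in favour of saving to a BytesIO) turns the workbook into response bytes.

from openpyxl import Workbook
from openpyxl.writer.excel import save_virtual_workbook

wb = Workbook()
ws = wb.active

row_count = 0
for row in [['*Hostname', 'Comment'], ['web-01', 'test']]:
    row_count += 1                                # 1-based, as in write_row
    for column_count, cell_value in enumerate(row, start=1):
        ws.cell(row=row_count, column=column_count, value=cell_value)

value = save_virtual_workbook(wb)                 # xlsx bytes for the response
print(value[:4] == b'PK\x03\x04')                 # True: xlsx is a zip container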
@@ -12,13 +12,16 @@ REST_FRAMEWORK = {
     'DEFAULT_RENDERER_CLASSES': (
         'rest_framework.renderers.JSONRenderer',
        # 'rest_framework.renderers.BrowsableAPIRenderer',
-        'common.drf.renders.JMSCSVRender',
+        'common.drf.renders.CSVFileRenderer',
+        'common.drf.renders.ExcelFileRenderer',
+
     ),
     'DEFAULT_PARSER_CLASSES': (
         'rest_framework.parsers.JSONParser',
         'rest_framework.parsers.FormParser',
         'rest_framework.parsers.MultiPartParser',
-        'common.drf.parsers.JMSCSVParser',
+        'common.drf.parsers.CSVFileParser',
+        'common.drf.parsers.ExcelFileParser',
         'rest_framework.parsers.FileUploadParser',
     ),
     'DEFAULT_AUTHENTICATION_CLASSES': (
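With both classes registered after JSONRenderer and JSONParser, ordinary API traffic is untouched: JSON stays the default, and the file formats are only chosen when a client asks for them. DRF matches ?format=csv or ?format=xlsx on a GET against each renderer's format attribute, and an upload's Content-Type of text/csv or text/xlsx against each parser's media_type (the example endpoints in the client sketch near the top of this page are hypothetical; any view relying on these defaults gets import/export this way).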
@@ -105,3 +105,6 @@ azure-mgmt-compute==4.6.2
 azure-mgmt-network==2.7.0
 msrestazure==0.6.4
 adal==1.2.5
+openpyxl==3.0.5
+pyexcel==0.6.6
+pyexcel-xlsx==0.6.0