perf: 优化导入导出

perf: remove debug

perf: 修改账号导入导出

perf: 去掉一些 debug
pull/9921/head
ibuler 2023-03-10 15:52:07 +08:00 committed by Jiangjie.Bai
parent 3658ecce0c
commit fa3bfceddc
6 changed files with 132 additions and 49 deletions

View File

@ -26,6 +26,13 @@ __all__ = [
class AssetProtocolsSerializer(serializers.ModelSerializer): class AssetProtocolsSerializer(serializers.ModelSerializer):
port = serializers.IntegerField(required=False, allow_null=True, max_value=65535, min_value=1) port = serializers.IntegerField(required=False, allow_null=True, max_value=65535, min_value=1)
def to_file_representation(self, data):
return '{name}/{port}'.format(**data)
def to_file_internal_value(self, data):
name, port = data.split('/')
return {'name': name, 'port': port}
class Meta: class Meta:
model = Protocol model = Protocol
fields = ['name', 'port'] fields = ['name', 'port']
@ -121,7 +128,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
type = LabeledChoiceField(choices=AllTypes.choices(), read_only=True, label=_('Type')) type = LabeledChoiceField(choices=AllTypes.choices(), read_only=True, label=_('Type'))
labels = AssetLabelSerializer(many=True, required=False, label=_('Label')) labels = AssetLabelSerializer(many=True, required=False, label=_('Label'))
protocols = AssetProtocolsSerializer(many=True, required=False, label=_('Protocols'), default=()) protocols = AssetProtocolsSerializer(many=True, required=False, label=_('Protocols'), default=())
accounts = AssetAccountSerializer(many=True, required=False, write_only=True, label=_('Account')) accounts = AssetAccountSerializer(many=True, required=False, allow_null=True, write_only=True, label=_('Account'))
nodes_display = serializers.ListField(read_only=True, label=_("Node path"))
class Meta: class Meta:
model = Asset model = Asset
@ -133,11 +141,11 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
'nodes_display', 'accounts' 'nodes_display', 'accounts'
] ]
read_only_fields = [ read_only_fields = [
'category', 'type', 'connectivity', 'category', 'type', 'connectivity', 'auto_info',
'date_verified', 'created_by', 'date_created', 'date_verified', 'created_by', 'date_created',
'auto_info',
] ]
fields = fields_small + fields_fk + fields_m2m + read_only_fields fields = fields_small + fields_fk + fields_m2m + read_only_fields
fields_unexport = ['auto_info']
extra_kwargs = { extra_kwargs = {
'auto_info': {'label': _('Auto info')}, 'auto_info': {'label': _('Auto info')},
'name': {'label': _("Name")}, 'name': {'label': _("Name")},

View File

@ -3,13 +3,12 @@
from typing import Callable from typing import Callable
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from rest_framework.response import Response
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response
from common.const.http import POST from common.const.http import POST
__all__ = ['SuggestionMixin', 'RenderToJsonMixin'] __all__ = ['SuggestionMixin', 'RenderToJsonMixin']

View File

@ -1,11 +1,15 @@
import abc import abc
import json
import codecs import codecs
from rest_framework import serializers import json
import re
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.parsers import BaseParser from rest_framework import serializers
from rest_framework import status from rest_framework import status
from rest_framework.exceptions import ParseError, APIException from rest_framework.exceptions import ParseError, APIException
from rest_framework.parsers import BaseParser
from common.serializers.fields import ObjectRelatedField
from common.utils import get_logger from common.utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
@ -18,11 +22,11 @@ class FileContentOverflowedError(APIException):
class BaseFileParser(BaseParser): class BaseFileParser(BaseParser):
FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10 FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
serializer_cls = None serializer_cls = None
serializer_fields = None serializer_fields = None
obj_pattern = re.compile(r'^(.+)\(([a-z0-9-]+)\)$')
def check_content_length(self, meta): def check_content_length(self, meta):
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))) content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
@ -74,7 +78,7 @@ class BaseFileParser(BaseParser):
return s.translate(trans_table) return s.translate(trans_table)
@classmethod @classmethod
def process_row(cls, row): def load_row(cls, row):
""" """
构建json数据前的行处理 构建json数据前的行处理
""" """
@ -84,33 +88,59 @@ class BaseFileParser(BaseParser):
col = cls._replace_chinese_quote(col) col = cls._replace_chinese_quote(col)
# 列表/字典转换 # 列表/字典转换
if isinstance(col, str) and ( if isinstance(col, str) and (
(col.startswith('[') and col.endswith(']')) (col.startswith('[') and col.endswith(']')) or
or
(col.startswith("{") and col.endswith("}")) (col.startswith("{") and col.endswith("}"))
): ):
col = json.loads(col) col = json.loads(col)
new_row.append(col) new_row.append(col)
return new_row return new_row
def id_name_to_obj(self, v):
if not v or not isinstance(v, str):
return v
matched = self.obj_pattern.match(v)
if not matched:
return v
obj_name, obj_id = matched.groups()
if len(obj_id) < 36:
obj_id = int(obj_id)
return {'pk': obj_id, 'name': obj_name}
def parse_value(self, field, value):
if value is '-':
return None
elif hasattr(field, 'to_file_internal_value'):
value = field.to_file_internal_value(value)
elif isinstance(field, serializers.BooleanField):
value = value.lower() in ['true', '1', 'yes']
elif isinstance(field, serializers.ChoiceField):
value = value
elif isinstance(field, ObjectRelatedField):
if field.many:
value = [self.id_name_to_obj(v) for v in value]
else:
value = self.id_name_to_obj(value)
elif isinstance(field, serializers.ListSerializer):
value = [self.parse_value(field.child, v) for v in value]
elif isinstance(field, serializers.Serializer):
value = self.id_name_to_obj(value)
elif isinstance(field, serializers.ManyRelatedField):
value = [self.parse_value(field.child_relation, v) for v in value]
elif isinstance(field, serializers.ListField):
value = [self.parse_value(field.child, v) for v in value]
return value
def process_row_data(self, row_data): def process_row_data(self, row_data):
""" """
构建json数据后的行数据处理 构建json数据后的行数据处理
""" """
new_row_data = {} new_row = {}
serializer_fields = self.serializer_fields
for k, v in row_data.items(): for k, v in row_data.items():
if type(v) in [list, dict, int, bool] or (isinstance(v, str) and k.strip() and v.strip()): field = self.serializer_fields.get(k)
# 处理类似disk_info为字符串的'{}'的问题 v = self.parse_value(field, v)
if not isinstance(v, str) and isinstance(serializer_fields[k], serializers.CharField): new_row[k] = v
v = str(v) return new_row
# 处理 BooleanField 的问题, 导出是 'True', 'False'
if isinstance(v, str) and v.strip().lower() == 'true':
v = True
elif isinstance(v, str) and v.strip().lower() == 'false':
v = False
new_row_data[k] = v
return new_row_data
def generate_data(self, fields_name, rows): def generate_data(self, fields_name, rows):
data = [] data = []
@ -118,7 +148,7 @@ class BaseFileParser(BaseParser):
# 空行不处理 # 空行不处理
if not any(row): if not any(row):
continue continue
row = self.process_row(row) row = self.load_row(row)
row_data = dict(zip(fields_name, row)) row_data = dict(zip(fields_name, row))
row_data = self.process_row_data(row_data) row_data = self.process_row_data(row_data)
data.append(row_data) data.append(row_data)
@ -139,7 +169,6 @@ class BaseFileParser(BaseParser):
raise ParseError('The resource does not support imports!') raise ParseError('The resource does not support imports!')
self.check_content_length(meta) self.check_content_length(meta)
try: try:
stream_data = self.get_stream_data(stream) stream_data = self.get_stream_data(stream)
rows = self.generate_rows(stream_data) rows = self.generate_rows(stream_data)
@ -148,6 +177,7 @@ class BaseFileParser(BaseParser):
# 给 `common.mixins.api.RenderToJsonMixin` 提供,暂时只能耦合 # 给 `common.mixins.api.RenderToJsonMixin` 提供,暂时只能耦合
column_title_field_pairs = list(zip(column_titles, field_names)) column_title_field_pairs = list(zip(column_titles, field_names))
column_title_field_pairs = [(k, v) for k, v in column_title_field_pairs if k and v]
if not hasattr(request, 'jms_context'): if not hasattr(request, 'jms_context'):
request.jms_context = {} request.jms_context = {}
request.jms_context['column_title_field_pairs'] = column_title_field_pairs request.jms_context['column_title_field_pairs'] = column_title_field_pairs
@ -157,4 +187,3 @@ class BaseFileParser(BaseParser):
except Exception as e: except Exception as e:
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
raise ParseError(_('Parse file error: {}').format(e)) raise ParseError(_('Parse file error: {}').format(e))

View File

@ -1,8 +1,11 @@
import abc import abc
from datetime import datetime from datetime import datetime
from rest_framework import serializers
from rest_framework.renderers import BaseRenderer from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json from rest_framework.utils import encoders, json
from common.serializers.fields import ObjectRelatedField
from common.utils import get_logger from common.utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
@ -38,18 +41,27 @@ class BaseFileRenderer(BaseRenderer):
def get_rendered_fields(self): def get_rendered_fields(self):
fields = self.serializer.fields fields = self.serializer.fields
if self.template == 'import': if self.template == 'import':
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id'] fields = [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
elif self.template == 'update': elif self.template == 'update':
return [v for k, v in fields.items() if not v.read_only and k != "org_id"] fields = [v for k, v in fields.items() if not v.read_only and k != "org_id"]
else: else:
return [v for k, v in fields.items() if not v.write_only and k != "org_id"] fields = [v for k, v in fields.items() if not v.write_only and k != "org_id"]
meta = getattr(self.serializer, 'Meta', None)
if meta:
fields_unexport = getattr(meta, 'fields_unexport', [])
fields = [v for v in fields if v.field_name not in fields_unexport]
return fields
@staticmethod @staticmethod
def get_column_titles(render_fields): def get_column_titles(render_fields):
return [ titles = []
'*{}'.format(field.label) if field.required else str(field.label) for field in render_fields:
for field in render_fields name = field.label
] if field.required:
name = '*' + name
titles.append(name)
return titles
def process_data(self, data): def process_data(self, data):
results = data['results'] if 'results' in data else data results = data['results'] if 'results' in data else data
@ -59,7 +71,6 @@ class BaseFileRenderer(BaseRenderer):
if self.template == 'import': if self.template == 'import':
results = [results[0]] if results else results results = [results[0]] if results else results
else: else:
# 限制数据数量 # 限制数据数量
results = results[:10000] results = results[:10000]
@ -68,17 +79,53 @@ class BaseFileRenderer(BaseRenderer):
return results return results
@staticmethod @staticmethod
def generate_rows(data, render_fields): def to_id_name(value):
if value is None:
return '-'
pk = str(value.get('id', '') or value.get('pk', ''))
name = value.get('name') or value.get('display_name', '')
return '{}({})'.format(name, pk)
@staticmethod
def to_choice_name(value):
if value is None:
return '-'
value = value.get('value', '')
return value
def render_value(self, field, value):
if value is None:
value = '-'
elif hasattr(field, 'to_file_representation'):
value = field.to_file_representation(value)
elif isinstance(value, bool):
value = 'Yes' if value else 'No'
elif isinstance(field, serializers.ChoiceField):
value = value.get('value', '')
elif isinstance(field, ObjectRelatedField):
if field.many:
value = [self.to_id_name(v) for v in value]
else:
value = self.to_id_name(value)
elif isinstance(field, serializers.ListSerializer):
value = [self.render_value(field.child, v) for v in value]
elif isinstance(field, serializers.Serializer) and value.get('id'):
value = self.to_id_name(value)
elif isinstance(field, serializers.ManyRelatedField):
value = [self.render_value(field.child_relation, v) for v in value]
elif isinstance(field, serializers.ListField):
value = [self.render_value(field.child, v) for v in value]
if not isinstance(value, str):
value = json.dumps(value, cls=encoders.JSONEncoder, ensure_ascii=False)
return str(value)
def generate_rows(self, data, render_fields):
for item in data: for item in data:
row = [] row = []
for field in render_fields: for field in render_fields:
value = item.get(field.field_name) value = item.get(field.field_name)
if value is None: value = self.render_value(field, value)
value = ''
elif isinstance(value, dict):
value = json.dumps(value, ensure_ascii=False)
else:
value = str(value)
row.append(value) row.append(value)
yield row yield row
@ -134,6 +181,4 @@ class BaseFileRenderer(BaseRenderer):
logger.debug(e, exc_info=True) logger.debug(e, exc_info=True)
value = 'Render error! ({})'.format(self.media_type).encode('utf-8') value = 'Render error! ({})'.format(self.media_type).encode('utf-8')
return value return value
return value return value

View File

@ -1,6 +1,6 @@
from openpyxl import Workbook from openpyxl import Workbook
from openpyxl.writer.excel import save_virtual_workbook
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
from openpyxl.writer.excel import save_virtual_workbook
from .base import BaseFileRenderer from .base import BaseFileRenderer
@ -23,8 +23,8 @@ class ExcelFileRenderer(BaseFileRenderer):
for cell_value in row: for cell_value in row:
# 处理非法字符 # 处理非法字符
column_count += 1 column_count += 1
cell_value = ILLEGAL_CHARACTERS_RE.sub(r'', cell_value) cell_value = ILLEGAL_CHARACTERS_RE.sub(r'', str(cell_value))
self.ws.cell(row=self.row_count, column=column_count, value=cell_value) self.ws.cell(row=self.row_count, column=column_count, value=str(cell_value))
def get_rendered_value(self): def get_rendered_value(self):
value = save_virtual_workbook(self.wb) value = save_virtual_workbook(self.wb)

View File

@ -142,6 +142,7 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer
# 在serializer 上定义的字段 # 在serializer 上定义的字段
fields_custom = ["login_blocked", "password_strategy"] fields_custom = ["login_blocked", "password_strategy"]
fields = fields_verbose + fields_fk + fields_m2m + fields_custom fields = fields_verbose + fields_fk + fields_m2m + fields_custom
fields_unexport = ["avatar_url", ]
read_only_fields = [ read_only_fields = [
"date_joined", "last_login", "created_by", "date_joined", "last_login", "created_by",
@ -167,6 +168,7 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer
"role": {"default": "User"}, "role": {"default": "User"},
"is_otp_secret_key_bound": {"label": _("Is OTP bound")}, "is_otp_secret_key_bound": {"label": _("Is OTP bound")},
"phone": {"validators": [PhoneValidator()]}, "phone": {"validators": [PhoneValidator()]},
'mfa_level': {'label': _("MFA level")},
} }
def validate_password(self, password): def validate_password(self, password):