# Source: jumpserver/apps/common/renders/csv.py (84 lines, 2.8 KiB, Python)

# ~*~ coding: utf-8 ~*~
#
import unicodecsv
from datetime import datetime
from six import BytesIO
from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json
from ..utils import get_logger
# Module logger from the project's get_logger helper (imported from ..utils);
# JMSCSVRender.render uses it to log, at debug level, why a view could not
# provide a serializer for export.
logger = get_logger(__file__)
class JMSCSVRender(BaseRenderer):
    """Render serializer data as a CSV document.

    The ``template`` query parameter selects which serializer fields become
    columns:

    * ``import`` -- non-read-only fields except ``org_id``; at most the first
      record is emitted.
    * ``update`` -- non-read-only fields.
    * anything else (default ``export``) -- non-write-only fields.

    The first CSV row holds the field labels (falling back to field names).
    """

    media_type = 'text/csv'
    format = 'csv'

    @staticmethod
    def _get_header(fields, template):
        """Return the ordered field names used as CSV columns.

        :param fields: mapping of field name -> serializer field, as returned
            by ``serializer.get_fields()``.
        :param template: ``'import'``, ``'update'`` or anything else (export).
        :returns: list of field names.
        """
        if template == 'import':
            # Import templates exclude read-only fields and the org id.
            header = [
                k for k, v in fields.items()
                if not v.read_only and k != 'org_id'
            ]
        elif template == 'update':
            header = [k for k, v in fields.items() if not v.read_only]
        else:
            # template in ['export']
            header = [k for k, v in fields.items() if not v.write_only]
        return header

    @staticmethod
    def _gen_table(data, header, labels=None):
        """Yield CSV rows: one title row, then one row per record.

        :param data: iterable of dict records.
        :param header: ordered field names selecting the columns.
        :param labels: optional mapping of field name -> column title; a
            field without an entry is titled by its name.
        """
        labels = labels or {}
        yield [labels.get(k, k) for k in header]
        for item in data:
            # A record missing a column contributes None for that cell.
            row = [item.get(key) for key in header]
            yield row

    @staticmethod
    def _normalize_data(data):
        """Coerce decoded response data into a sequence of records.

        ``None`` becomes an empty list and a single dict becomes a
        one-element list. Iterating a dict directly would yield its keys,
        which have no ``.get`` and would crash ``_gen_table``.
        """
        if data is None:
            return []
        if isinstance(data, dict):
            return [data]
        return data

    def set_response_disposition(self, serializer, context):
        """Mark the response as a CSV attachment.

        The filename is ``<lower-cased model name>_<YYYY-mm-dd_HH-MM-SS>.csv``.
        Nothing is set when the context carries no response or the
        serializer has no ``Meta.model``.

        :param serializer: serializer instance obtained from the view.
        :param context: renderer context; ``response`` is read from it.
        """
        response = context.get('response')
        if not response:
            return
        if not (hasattr(serializer, 'Meta') and
                hasattr(serializer.Meta, "model")):
            return
        model_name = serializer.Meta.model.__name__.lower()
        now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        filename = "{}_{}.csv".format(model_name, now)
        disposition = 'attachment; filename="{}"'.format(filename)
        response['Content-Disposition'] = disposition

    def render(self, data, media_type=None, renderer_context=None):
        """Render ``data`` as CSV bytes.

        :param data: response data: a list of records, a single record dict,
            or ``None``.
        :param media_type: unused; part of the ``BaseRenderer`` interface.
        :param renderer_context: context dict; ``request``, ``view``,
            ``response`` and an optional ``encoding`` (default ``utf-8``)
            are read from it.
        :returns: the encoded CSV document. If the view cannot supply a
            serializer, returns a short error message as bytes instead.
        """
        renderer_context = renderer_context or {}
        encoding = renderer_context.get('encoding', 'utf-8')
        request = renderer_context.get('request')
        # Without a request there are no query params; default to export.
        if request is not None:
            template = request.query_params.get('template', 'export')
        else:
            template = 'export'
        view = renderer_context.get('view')

        # Round-trip through the JSON encoder so rich values (dates, etc.)
        # become plain JSON-compatible values before being written.
        data = json.loads(json.dumps(data, cls=encoders.JSONEncoder))
        data = self._normalize_data(data)
        if template == 'import':
            # An import template only needs one sample row, if any.
            data = data[:1]

        try:
            # A missing view (None) also lands here via AttributeError.
            serializer = view.get_serializer()
            self.set_response_disposition(serializer, renderer_context)
        except Exception as e:
            logger.debug(e, exc_info=True)
            return 'The resource not support export!'.encode('utf-8')

        fields = serializer.get_fields()
        header = self._get_header(fields, template)
        labels = {k: v.label for k, v in fields.items() if v.label}
        table = self._gen_table(data, header, labels)

        csv_buffer = BytesIO()
        csv_writer = unicodecsv.writer(csv_buffer, encoding=encoding)
        for row in table:
            csv_writer.writerow(row)
        return csv_buffer.getvalue()