mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			140 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			140 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
import abc
 | 
						|
from datetime import datetime
 | 
						|
from rest_framework.renderers import BaseRenderer
 | 
						|
from rest_framework.utils import encoders, json
 | 
						|
 | 
						|
from common.utils import get_logger
 | 
						|
 | 
						|
logger = get_logger(__file__)
 | 
						|
 | 
						|
 | 
						|
class BaseFileRenderer(BaseRenderer):
 | 
						|
    # 渲染模版标识, 导入、导出、更新模版: ['import', 'update', 'export']
 | 
						|
    template = 'export'
 | 
						|
    serializer = None
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def _check_validation_data(data):
 | 
						|
        detail_key = "detail"
 | 
						|
        if detail_key in data:
 | 
						|
            return False
 | 
						|
        return True
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def _json_format_response(response_data):
 | 
						|
        return json.dumps(response_data)
 | 
						|
 | 
						|
    def set_response_disposition(self, response):
 | 
						|
        serializer = self.serializer
 | 
						|
        if response and hasattr(serializer, 'Meta') and hasattr(serializer.Meta, "model"):
 | 
						|
            filename_prefix = serializer.Meta.model.__name__.lower()
 | 
						|
        else:
 | 
						|
            filename_prefix = 'download'
 | 
						|
        now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
 | 
						|
        filename = "{}_{}.{}".format(filename_prefix, now, self.format)
 | 
						|
        disposition = 'attachment; filename="{}"'.format(filename)
 | 
						|
        response['Content-Disposition'] = disposition
 | 
						|
 | 
						|
    def get_rendered_fields(self):
 | 
						|
        fields = self.serializer.fields
 | 
						|
        if self.template == 'import':
 | 
						|
            return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
 | 
						|
        elif self.template == 'update':
 | 
						|
            return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
 | 
						|
        else:
 | 
						|
            return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def get_column_titles(render_fields):
 | 
						|
        return [
 | 
						|
            '*{}'.format(field.label) if field.required else str(field.label)
 | 
						|
            for field in render_fields
 | 
						|
        ]
 | 
						|
 | 
						|
    def process_data(self, data):
 | 
						|
        results = data['results'] if 'results' in data else data
 | 
						|
 | 
						|
        if isinstance(results, dict):
 | 
						|
            results = [results]
 | 
						|
 | 
						|
        if self.template == 'import':
 | 
						|
            results = [results[0]] if results else results
 | 
						|
 | 
						|
        else:
 | 
						|
            # 限制数据数量
 | 
						|
            results = results[:10000]
 | 
						|
        # 会将一些 UUID 字段转化为 string
 | 
						|
        results = json.loads(json.dumps(results, cls=encoders.JSONEncoder))
 | 
						|
        return results
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def generate_rows(data, render_fields):
 | 
						|
        for item in data:
 | 
						|
            row = []
 | 
						|
            for field in render_fields:
 | 
						|
                value = item.get(field.field_name)
 | 
						|
                if value is None:
 | 
						|
                    value = ''
 | 
						|
                elif isinstance(value, dict):
 | 
						|
                    value = json.dumps(value, ensure_ascii=False)
 | 
						|
                else:
 | 
						|
                    value = str(value)
 | 
						|
                row.append(value)
 | 
						|
            yield row
 | 
						|
 | 
						|
    @abc.abstractmethod
 | 
						|
    def initial_writer(self):
 | 
						|
        raise NotImplementedError
 | 
						|
 | 
						|
    def write_column_titles(self, column_titles):
 | 
						|
        self.write_row(column_titles)
 | 
						|
 | 
						|
    def write_rows(self, rows):
 | 
						|
        for row in rows:
 | 
						|
            self.write_row(row)
 | 
						|
 | 
						|
    @abc.abstractmethod
 | 
						|
    def write_row(self, row):
 | 
						|
        raise NotImplementedError
 | 
						|
 | 
						|
    @abc.abstractmethod
 | 
						|
    def get_rendered_value(self):
 | 
						|
        raise NotImplementedError
 | 
						|
 | 
						|
    def render(self, data, accepted_media_type=None, renderer_context=None):
 | 
						|
        if data is None:
 | 
						|
            return bytes()
 | 
						|
 | 
						|
        if not self._check_validation_data(data):
 | 
						|
            return self._json_format_response(data)
 | 
						|
 | 
						|
        try:
 | 
						|
            renderer_context = renderer_context or {}
 | 
						|
            request = renderer_context['request']
 | 
						|
            response = renderer_context['response']
 | 
						|
            view = renderer_context['view']
 | 
						|
            self.template = request.query_params.get('template', 'export')
 | 
						|
            self.serializer = view.get_serializer()
 | 
						|
            self.set_response_disposition(response)
 | 
						|
        except Exception as e:
 | 
						|
            logger.debug(e, exc_info=True)
 | 
						|
            value = 'The resource not support export!'.encode('utf-8')
 | 
						|
            return value
 | 
						|
 | 
						|
        try:
 | 
						|
            rendered_fields = self.get_rendered_fields()
 | 
						|
            column_titles = self.get_column_titles(rendered_fields)
 | 
						|
            data = self.process_data(data)
 | 
						|
            rows = self.generate_rows(data, rendered_fields)
 | 
						|
            self.initial_writer()
 | 
						|
            self.write_column_titles(column_titles)
 | 
						|
            self.write_rows(rows)
 | 
						|
            value = self.get_rendered_value()
 | 
						|
        except Exception as e:
 | 
						|
            logger.debug(e, exc_info=True)
 | 
						|
            value = 'Render error! ({})'.format(self.media_type).encode('utf-8')
 | 
						|
            return value
 | 
						|
 | 
						|
        return value
 | 
						|
 |