mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			224 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			224 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
from logging import StreamHandler
 | 
						|
from threading import get_ident
 | 
						|
 | 
						|
from celery import current_task
 | 
						|
from celery.signals import task_prerun, task_postrun
 | 
						|
from django.conf import settings
 | 
						|
from kombu import Connection, Exchange, Queue, Producer
 | 
						|
from kombu.mixins import ConsumerMixin
 | 
						|
 | 
						|
from .utils import get_celery_task_log_path
 | 
						|
from ..const import CELERY_LOG_MAGIC_MARK
 | 
						|
 | 
						|
# Routing key used both when publishing log events and when binding the
# consumer queue below, so the two always agree.
routing_key = 'celery_log'

# A "direct" exchange delivers each message to the queues whose binding key
# equals the message's routing key.
celery_log_exchange = Exchange('celery_log_exchange', type='direct')

# Queue list handed to the kombu Consumer in CeleryLoggerConsumer.
celery_log_queue = [
    Queue('celery_log', celery_log_exchange, routing_key=routing_key),
]
 | 
						|
 | 
						|
 | 
						|
class CeleryLoggerConsumer(ConsumerMixin):
    """Consume celery task log events from the ``celery_log`` queue.

    Each message body is expected to be a mapping with ``action``,
    ``task_id`` and (for log events) ``msg`` keys.  ``process_task`` compares
    ``action`` with the ``CeleryLoggerProducer.ACTION_*`` codes and calls the
    matching ``handle_task_*`` hook.  The hooks are no-ops here and are meant
    to be overridden.  Unknown actions are ignored.

    NOTE(review): this class never calls ``message.ack()``; presumably
    overriding hooks are expected to do so -- confirm.
    """

    def __init__(self):
        # ConsumerMixin runs its consume loop against ``self.connection``.
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    def get_consumers(self, Consumer, channel):
        """Return the single kombu consumer bound to the log queue."""
        # SECURITY(review): 'pickle' lets anyone able to publish to this
        # broker queue execute code in the consumer process, and
        # CeleryLoggerProducer only ever publishes JSON.  Consider accepting
        # ['json'] only.
        log_consumer = Consumer(
            queues=celery_log_queue,
            accept=['pickle', 'json'],
            callbacks=[self.process_task],
        )
        return [log_consumer]

    def handle_task_start(self, task_id, message):
        """Hook: a task-start event for *task_id* arrived (no-op by default)."""

    def handle_task_end(self, task_id, message):
        """Hook: a task-end event for *task_id* arrived (no-op by default)."""

    def handle_task_log(self, task_id, msg, message):
        """Hook: log line *msg* for *task_id* arrived (no-op by default)."""

    def process_task(self, body, message):
        """kombu callback: route *body* to the hook matching its action."""
        action, task_id, msg = (body.get(key) for key in ('action', 'task_id', 'msg'))
        actions = CeleryLoggerProducer
        if action == actions.ACTION_TASK_LOG:
            self.handle_task_log(task_id, msg, message)
            return
        if action == actions.ACTION_TASK_START:
            self.handle_task_start(task_id, message)
            return
        if action == actions.ACTION_TASK_END:
            self.handle_task_end(task_id, message)
 | 
						|
 | 
						|
 | 
						|
class CeleryLoggerProducer:
    """Publish celery task lifecycle and log events to ``celery_log_exchange``.

    Every event is a JSON-serialised dict with a ``task_id`` and an
    ``action`` (one of the ``ACTION_*`` codes below).  Log events also carry
    the formatted ``msg``.
    """

    # Integer action codes (0, 1, 2), compared by CeleryLoggerConsumer.process_task.
    ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3)

    def __init__(self):
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    @property
    def producer(self):
        """A kombu ``Producer`` on ``self.connection``, built fresh on every access."""
        return Producer(self.connection)

    def publish(self, payload):
        """JSON-encode *payload* and send it with the module-level ``routing_key``.

        The exchange is declared as part of the publish call.
        """
        kombu_producer = self.producer
        kombu_producer.publish(
            payload,
            serializer='json',
            exchange=celery_log_exchange,
            routing_key=routing_key,
            declare=[celery_log_exchange],
        )

    def log(self, task_id, msg):
        """Publish one formatted log line *msg* for *task_id*."""
        return self.publish({'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG})

    def read(self):
        """Placeholder; reading is not supported by this producer."""

    def flush(self):
        """No-op: each event is handed to kombu as soon as it is published."""

    def task_end(self, task_id):
        """Publish the end-of-task event for *task_id*."""
        return self.publish({'task_id': task_id, 'action': self.ACTION_TASK_END})

    def task_start(self, task_id):
        """Publish the start-of-task event for *task_id*."""
        return self.publish({'task_id': task_id, 'action': self.ACTION_TASK_START})
 | 
						|
 | 
						|
 | 
						|
class CeleryTaskLoggerHandler(StreamHandler):
    """Base logging handler that routes records to the running celery task.

    On construction it subscribes to celery's ``task_prerun`` and
    ``task_postrun`` signals, which end up in ``handle_task_start`` /
    ``handle_task_end``.  ``emit`` drops records logged outside a task and
    passes the rest, keyed by the task's *root* id, to ``write_task_log``.
    All three hooks are no-ops here; subclasses implement the storage.
    """

    terminator = '\r\n'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Subscribe this handler to the per-task lifecycle signals.
        task_prerun.connect(self.on_task_start)
        task_postrun.connect(self.on_start_end)

    @staticmethod
    def get_current_task_id():
        """Return the ``root_id`` of the currently executing task, or None."""
        return current_task.request.root_id if current_task else None

    def on_task_start(self, sender, task_id, **kwargs):
        """``task_prerun`` receiver; forwards to ``handle_task_start``."""
        return self.handle_task_start(task_id)

    def on_start_end(self, sender, task_id, **kwargs):
        """``task_postrun`` receiver (despite the name); forwards to ``handle_task_end``."""
        return self.handle_task_end(task_id)

    def after_task_publish(self, sender, body, **kwargs):
        """Unused placeholder; not connected to any signal here."""

    def emit(self, record):
        """Hand *record* to ``write_task_log`` when a celery task is running.

        Any exception raised while writing or flushing goes to
        ``handleError`` instead of escaping into the task.
        """
        task_id = self.get_current_task_id()
        if task_id:
            try:
                self.write_task_log(task_id, record)
                self.flush()
            except Exception:
                self.handleError(record)

    def write_task_log(self, task_id, msg):
        """Hook: persist log record *msg* (a ``LogRecord``) for *task_id*."""

    def handle_task_start(self, task_id):
        """Hook: task *task_id* is about to run."""

    def handle_task_end(self, task_id):
        """Hook: task *task_id* has finished."""
 | 
						|
 | 
						|
 | 
						|
class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler):
    """Variant of CeleryTaskLoggerHandler that routes records by thread.

    Instead of looking up the current celery task, ``emit`` keys every record
    by the emitting thread's identifier (``threading.get_ident`` as a str)
    and passes it to ``write_thread_task_log``, which subclasses implement.
    ``handleError`` is a no-op, so records that fail to be written are
    dropped silently.
    """

    @staticmethod
    def get_current_thread_id():
        """Return the calling thread's identifier as a string."""
        return str(get_ident())

    def emit(self, record):
        """Write *record* for the current thread without raising into the caller.

        Per the ``logging.Handler.emit`` contract, any failure is routed to
        ``handleError``.  That covers, e.g., a subclass signalling a missing
        per-thread file with ``ValueError``, an ``OSError`` from the write, or
        a formatting error.  Previously only ``ValueError`` was caught, so other
        errors aborted the task that merely tried to log.  ``RecursionError``
        is re-raised, as the standard-library handlers do.
        """
        thread_id = self.get_current_thread_id()
        try:
            self.write_thread_task_log(thread_id, record)
            self.flush()
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)

    def write_thread_task_log(self, thread_id, msg):
        """Hook: persist log record *msg* for the thread *thread_id*."""
        pass

    def handle_task_start(self, task_id):
        """Hook: task *task_id* is about to run (no-op by default)."""
        pass

    def handle_task_end(self, task_id):
        """Hook: task *task_id* has finished (no-op by default)."""
        pass

    def handleError(self, record) -> None:
        """Silently discard errors raised while emitting *record*."""
        pass
 | 
						|
 | 
						|
 | 
						|
class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler):
    """Forward each task log record to the message broker.

    Records are formatted with this handler's formatter and published as log
    events through a ``CeleryLoggerProducer``.
    """

    def __init__(self):
        # The producer is created before the base initialiser runs.
        self.producer = CeleryLoggerProducer()
        super().__init__(stream=None)

    def write_task_log(self, task_id, record):
        """Publish the formatted *record* as a log line for *task_id*."""
        self.producer.log(task_id, self.format(record))

    def flush(self):
        """Delegate flushing to the producer instead of the stream."""
        self.producer.flush()
 | 
						|
 | 
						|
 | 
						|
class CeleryTaskFileHandler(CeleryTaskLoggerHandler):
    """Write task log records to one task log file at a time.

    ``handle_task_start`` opens ``get_celery_task_log_path(task_id)`` in
    append mode.  Every record emitted after that is appended to it, whatever
    task is current, until ``handle_task_end`` closes the file.  Records
    emitted while no file is open are dropped.
    """

    def __init__(self, *args, **kwargs):
        self.f = None  # currently open task log file (text mode), or None
        super().__init__(*args, **kwargs)

    def emit(self, record):
        """Append the formatted *record* plus ``terminator`` to the open file.

        Write and format errors go to ``handleError`` rather than propagating
        into the task that logged.  This matches the base handler's ``emit``,
        which this method overrides.
        """
        f = self.f
        if not f or f.closed:
            return
        try:
            f.write(self.format(record))
            f.write(self.terminator)
            self.flush()
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)

    def flush(self):
        """Flush the open log file; a missing or closed file is ignored."""
        if self.f and not self.f.closed:
            self.f.flush()

    def handle_task_start(self, task_id):
        """Open (append mode) the log file for *task_id* and make it current."""
        log_path = get_celery_task_log_path(task_id)
        # A file left open by an earlier task would otherwise be replaced
        # without being closed, leaking its descriptor.
        self._close_file()
        # Explicit UTF-8: the locale default may be ASCII, which fails on
        # non-ASCII log text.  UTF-8 matches what the thread-based file
        # handler writes to these same paths.
        self.f = open(log_path, 'a', encoding='utf-8')

    def handle_task_end(self, task_id):
        """Close the current log file, if any."""
        self._close_file()

    def _close_file(self):
        """Close ``self.f`` if it is open (the closed object is kept)."""
        if self.f and not self.f.closed:
            self.f.close()
 | 
						|
 | 
						|
 | 
						|
class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
    """Per-thread task log files for workers that run tasks in threads.

    ``handle_task_start`` opens ``get_celery_task_log_path(task_id)`` in
    binary append mode and registers it under the calling thread's id.
    Records are written, UTF-8 encoded, to the file of the thread that emits
    them.  ``handle_task_end`` appends ``CELERY_LOG_MAGIC_MARK`` and closes
    the file.
    """

    def __init__(self, *args, **kwargs):
        self.thread_id_fd_mapper = {}  # thread id (str) -> open binary log file
        self.task_id_thread_id_mapper = {}  # task id -> thread id that started it
        super().__init__(*args, **kwargs)

    def write_thread_task_log(self, thread_id, record):
        """Append the formatted *record* to *thread_id*'s file and flush it.

        Raises:
            ValueError: no log file is registered for *thread_id*.
        """
        f = self.thread_id_fd_mapper.get(thread_id, None)
        if not f:
            raise ValueError('Not found thread task file')
        msg = self.format(record)
        f.write(msg.encode())
        f.write(self.terminator.encode())
        f.flush()

    def flush(self):
        """Flush every registered, still-open per-thread log file."""
        # Iterate over a snapshot.  Other threads add and remove entries while
        # this runs, and iterating the live dict can raise "dictionary changed
        # size during iteration".
        for f in list(self.thread_id_fd_mapper.values()):
            if not f.closed:
                f.flush()

    def handle_task_start(self, task_id):
        """Open *task_id*'s log file and bind it to the calling thread."""
        log_path = get_celery_task_log_path(task_id)
        thread_id = self.get_current_thread_id()
        self.task_id_thread_id_mapper[task_id] = thread_id
        f = open(log_path, 'ab')
        previous = self.thread_id_fd_mapper.get(thread_id)
        self.thread_id_fd_mapper[thread_id] = f
        # A file already registered for this thread is being replaced.  Its
        # entry is gone, so nothing would ever close it; close it now rather
        # than leak the descriptor.
        if previous is not None and previous is not f and not previous.closed:
            previous.close()

    def handle_task_end(self, task_id):
        """Append the end marker to *task_id*'s file and close it."""
        # Drop the task mapping up front so it is removed even if the write
        # below fails.
        thread_id = self.task_id_thread_id_mapper.pop(task_id, '')
        f = self.thread_id_fd_mapper.pop(thread_id, None)
        if f is None or f.closed:
            return
        try:
            f.write(CELERY_LOG_MAGIC_MARK)
        finally:
            f.close()
 |