jumpserver/apps/ops/celery/logger.py

224 lines
6.4 KiB
Python

from logging import StreamHandler
from threading import get_ident
from celery import current_task
from celery.signals import task_prerun, task_postrun
from django.conf import settings
from kombu import Connection, Exchange, Queue, Producer
from kombu.mixins import ConsumerMixin
from .utils import get_celery_task_log_path
from ..const import CELERY_LOG_MAGIC_MARK
routing_key = 'celery_log'
celery_log_exchange = Exchange('celery_log_exchange', type='direct')
celery_log_queue = [Queue('celery_log', celery_log_exchange, routing_key=routing_key)]
class CeleryLoggerConsumer(ConsumerMixin):
def __init__(self):
self.connection = Connection(settings.CELERY_LOG_BROKER_URL)
def get_consumers(self, Consumer, channel):
return [Consumer(queues=celery_log_queue,
accept=['pickle', 'json'],
callbacks=[self.process_task])
]
def handle_task_start(self, task_id, message):
pass
def handle_task_end(self, task_id, message):
pass
def handle_task_log(self, task_id, msg, message):
pass
def process_task(self, body, message):
action = body.get('action')
task_id = body.get('task_id')
msg = body.get('msg')
if action == CeleryLoggerProducer.ACTION_TASK_LOG:
self.handle_task_log(task_id, msg, message)
elif action == CeleryLoggerProducer.ACTION_TASK_START:
self.handle_task_start(task_id, message)
elif action == CeleryLoggerProducer.ACTION_TASK_END:
self.handle_task_end(task_id, message)
class CeleryLoggerProducer:
ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3)
def __init__(self):
self.connection = Connection(settings.CELERY_LOG_BROKER_URL)
@property
def producer(self):
return Producer(self.connection)
def publish(self, payload):
self.producer.publish(
payload, serializer='json', exchange=celery_log_exchange,
declare=[celery_log_exchange], routing_key=routing_key
)
def log(self, task_id, msg):
payload = {'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG}
return self.publish(payload)
def read(self):
pass
def flush(self):
pass
def task_end(self, task_id):
payload = {'task_id': task_id, 'action': self.ACTION_TASK_END}
return self.publish(payload)
def task_start(self, task_id):
payload = {'task_id': task_id, 'action': self.ACTION_TASK_START}
return self.publish(payload)
class CeleryTaskLoggerHandler(StreamHandler):
terminator = '\r\n'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
task_prerun.connect(self.on_task_start)
task_postrun.connect(self.on_start_end)
@staticmethod
def get_current_task_id():
if not current_task:
return
task_id = current_task.request.root_id
return task_id
def on_task_start(self, sender, task_id, **kwargs):
return self.handle_task_start(task_id)
def on_start_end(self, sender, task_id, **kwargs):
return self.handle_task_end(task_id)
def after_task_publish(self, sender, body, **kwargs):
pass
def emit(self, record):
task_id = self.get_current_task_id()
if not task_id:
return
try:
self.write_task_log(task_id, record)
self.flush()
except Exception:
self.handleError(record)
def write_task_log(self, task_id, msg):
pass
def handle_task_start(self, task_id):
pass
def handle_task_end(self, task_id):
pass
class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler):
@staticmethod
def get_current_thread_id():
return str(get_ident())
def emit(self, record):
thread_id = self.get_current_thread_id()
try:
self.write_thread_task_log(thread_id, record)
self.flush()
except ValueError:
self.handleError(record)
def write_thread_task_log(self, thread_id, msg):
pass
def handle_task_start(self, task_id):
pass
def handle_task_end(self, task_id):
pass
def handleError(self, record) -> None:
pass
class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler):
def __init__(self):
self.producer = CeleryLoggerProducer()
super().__init__(stream=None)
def write_task_log(self, task_id, record):
msg = self.format(record)
self.producer.log(task_id, msg)
def flush(self):
self.producer.flush()
class CeleryTaskFileHandler(CeleryTaskLoggerHandler):
def __init__(self, *args, **kwargs):
self.f = None
super().__init__(*args, **kwargs)
def emit(self, record):
msg = self.format(record)
if not self.f or self.f.closed:
return
self.f.write(msg)
self.f.write(self.terminator)
self.flush()
def flush(self):
self.f and self.f.flush()
def handle_task_start(self, task_id):
log_path = get_celery_task_log_path(task_id)
self.f = open(log_path, 'a')
def handle_task_end(self, task_id):
self.f and self.f.close()
class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
def __init__(self, *args, **kwargs):
self.thread_id_fd_mapper = {}
self.task_id_thread_id_mapper = {}
super().__init__(*args, **kwargs)
def write_thread_task_log(self, thread_id, record):
f = self.thread_id_fd_mapper.get(thread_id, None)
if not f:
raise ValueError('Not found thread task file')
msg = self.format(record)
f.write(msg.encode())
f.write(self.terminator.encode())
f.flush()
def flush(self):
for f in self.thread_id_fd_mapper.values():
f.flush()
def handle_task_start(self, task_id):
log_path = get_celery_task_log_path(task_id)
thread_id = self.get_current_thread_id()
self.task_id_thread_id_mapper[task_id] = thread_id
f = open(log_path, 'ab')
self.thread_id_fd_mapper[thread_id] = f
def handle_task_end(self, task_id):
ident_id = self.task_id_thread_id_mapper.get(task_id, '')
f = self.thread_id_fd_mapper.pop(ident_id, None)
if f and not f.closed:
f.write(CELERY_LOG_MAGIC_MARK)
f.close()
self.task_id_thread_id_mapper.pop(task_id, None)