|
|
|
from logging import StreamHandler
|
|
|
|
from threading import get_ident
|
|
|
|
|
|
|
|
from celery import current_task
|
|
|
|
from celery.signals import task_prerun, task_postrun
|
|
|
|
from django.conf import settings
|
|
|
|
from kombu import Connection, Exchange, Queue, Producer
|
|
|
|
from kombu.mixins import ConsumerMixin
|
|
|
|
|
|
|
|
from .utils import get_celery_task_log_path
|
|
|
|
from ..const import CELERY_LOG_MAGIC_MARK
|
|
|
|
|
|
|
|
# Routing key shared by the log producer (publish) and consumer (queue binding).
routing_key = 'celery_log'

# Direct exchange through which task log events are published.
celery_log_exchange = Exchange('celery_log_exchange', type='direct')

# Queues consumed by CeleryLoggerConsumer: a single queue bound to the
# exchange above with the shared routing key.
celery_log_queue = [
    Queue('celery_log', exchange=celery_log_exchange, routing_key=routing_key),
]
|
|
|
|
|
|
|
|
|
|
|
|
class CeleryLoggerConsumer(ConsumerMixin):
    """kombu consumer for task log events published on ``celery_log_queue``.

    Each decoded message body is dispatched on its ``action`` field to one of
    the ``handle_task_*`` hooks. The hooks are no-ops in this class
    (presumably overridden by a subclass). This class never acks or rejects
    messages itself; the hooks receive the raw kombu ``message`` for that.
    """

    def __init__(self):
        # ConsumerMixin drives its consume loop from ``self.connection``.
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    def get_consumers(self, Consumer, channel):
        """Return the single consumer that feeds bodies into :meth:`process_task`."""
        # NOTE(review): 'pickle' is accepted even though CeleryLoggerProducer
        # only publishes JSON. Unpickling broker messages runs arbitrary code
        # if untrusted parties can publish to this broker; consider ['json'].
        log_consumer = Consumer(
            queues=celery_log_queue,
            accept=['pickle', 'json'],
            callbacks=[self.process_task],
        )
        return [log_consumer]

    def handle_task_start(self, task_id, message):
        """Hook for an ACTION_TASK_START event; does nothing here."""
        pass

    def handle_task_end(self, task_id, message):
        """Hook for an ACTION_TASK_END event; does nothing here."""
        pass

    def handle_task_log(self, task_id, msg, message):
        """Hook for an ACTION_TASK_LOG event carrying log text ``msg``; does nothing here."""
        pass

    def process_task(self, body, message):
        """kombu callback: route ``body`` to the hook matching ``body['action']``.

        Bodies whose action matches none of the producer's ACTION_* constants
        are ignored.
        """
        action, task_id, msg = (body.get(key) for key in ('action', 'task_id', 'msg'))
        actions = CeleryLoggerProducer
        if action == actions.ACTION_TASK_LOG:
            self.handle_task_log(task_id, msg, message)
        elif action == actions.ACTION_TASK_START:
            self.handle_task_start(task_id, message)
        elif action == actions.ACTION_TASK_END:
            self.handle_task_end(task_id, message)
|
|
|
|
|
|
|
|
|
|
|
|
class CeleryLoggerProducer:
    """Publish task lifecycle and log events to ``celery_log_exchange``.

    Every event is a JSON-serialized dict with ``task_id`` and ``action``
    (one of the ACTION_* constants); log events also carry ``msg``.
    """

    # Event kinds, in wire order: 0 = start, 1 = log line, 2 = end.
    ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3)

    def __init__(self):
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    @property
    def producer(self):
        """A new kombu Producer on ``self.connection``, built on every access."""
        return Producer(self.connection)

    def publish(self, payload):
        """Send ``payload`` as JSON to the log exchange with the shared routing key.

        The exchange is declared on each publish so it exists even if no
        consumer has declared it yet.
        """
        options = dict(
            serializer='json',
            exchange=celery_log_exchange,
            declare=[celery_log_exchange],
            routing_key=routing_key,
        )
        self.producer.publish(payload, **options)

    def _send(self, task_id, action, **extra):
        """Build an event dict (keys: task_id, *extra, action) and publish it."""
        payload = {'task_id': task_id}
        payload.update(extra)
        payload['action'] = action
        return self.publish(payload)

    def log(self, task_id, msg):
        """Publish one formatted log line ``msg`` for ``task_id``."""
        return self._send(task_id, self.ACTION_TASK_LOG, msg=msg)

    def read(self):
        """Placeholder; does nothing."""
        pass

    def flush(self):
        """Placeholder; publishing is not buffered here, so nothing to flush."""
        pass

    def task_end(self, task_id):
        """Publish the end-of-task event for ``task_id``."""
        return self._send(task_id, self.ACTION_TASK_END)

    def task_start(self, task_id):
        """Publish the start-of-task event for ``task_id``."""
        return self._send(task_id, self.ACTION_TASK_START)
|
|
|
|
|
|
|
|
|
|
|
|
class CeleryTaskLoggerHandler(StreamHandler):
    """Base logging handler that attributes records to the running Celery task.

    On construction it subscribes to Celery's ``task_prerun`` and
    ``task_postrun`` signals and forwards them to the ``handle_task_start``
    and ``handle_task_end`` hooks. ``emit`` forwards only records logged
    while a task is current, keyed by that task's *root* id, to
    ``write_task_log``. All hooks are no-ops here; subclasses fill them in.
    """

    # Line terminator appended by subclasses after each formatted record.
    terminator = '\r\n'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        subscriptions = (
            (task_prerun, self.on_task_start),
            (task_postrun, self.on_start_end),
        )
        for signal, receiver in subscriptions:
            signal.connect(receiver)

    @staticmethod
    def get_current_task_id():
        """Return the root id of the task being executed, or None outside a task."""
        task = current_task
        if not task:
            return None
        return task.request.root_id

    def on_task_start(self, sender, task_id, **kwargs):
        """``task_prerun`` receiver: delegate to :meth:`handle_task_start`."""
        return self.handle_task_start(task_id)

    def on_start_end(self, sender, task_id, **kwargs):
        """``task_postrun`` receiver (fires at task end): delegate to :meth:`handle_task_end`."""
        return self.handle_task_end(task_id)

    def after_task_publish(self, sender, body, **kwargs):
        """Unused hook; this class does not connect it to any signal."""
        pass

    def emit(self, record):
        """Pass ``record`` to :meth:`write_task_log` when a task is current.

        Records logged outside a task are dropped. Any failure while writing
        or flushing is routed to ``handleError``, per the logging contract.
        """
        task_id = self.get_current_task_id()
        if task_id:
            try:
                self.write_task_log(task_id, record)
                self.flush()
            except Exception:
                self.handleError(record)

    def write_task_log(self, task_id, msg):
        """Hook: persist a record for ``task_id`` (``emit`` passes the LogRecord)."""
        pass

    def handle_task_start(self, task_id):
        """Hook run when a task starts; no-op by default."""
        pass

    def handle_task_end(self, task_id):
        """Hook run when a task finishes; no-op by default."""
        pass
|
|
|
|
|
|
|
|
|
|
|
|
class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler):
    """Task log handler that routes records by emitting thread, not by task id.

    ``emit`` hands each record, together with the current thread's identifier,
    to ``write_thread_task_log``; subclasses decide where it goes. Every error
    raised while writing is routed to ``handleError``, which this class
    overrides to drop it silently. For example, a subclass may raise when the
    thread has no task destination, and that record is then discarded.
    """

    @staticmethod
    def get_current_thread_id():
        """Return the calling thread's identifier as a string."""
        return str(get_ident())

    def emit(self, record):
        thread_id = self.get_current_thread_id()
        try:
            self.write_thread_task_log(thread_id, record)
            self.flush()
        except Exception:
            # Catch everything, not just the ValueError raised for threads
            # without a destination. A logging handler must not let OSError,
            # formatting errors, etc. escape into the code that logged. This
            # matches the base class emit.
            self.handleError(record)

    def write_thread_task_log(self, thread_id, msg):
        """Hook: write ``msg`` (the LogRecord) for ``thread_id``; no-op here."""
        pass

    def handle_task_start(self, task_id):
        """Hook run when a task starts; no-op here."""
        pass

    def handle_task_end(self, task_id):
        """Hook run when a task finishes; no-op here."""
        pass

    def handleError(self, record) -> None:
        """Silently drop records whose emission failed (deliberate best effort)."""
        pass
|
|
|
|
|
|
|
|
|
|
|
|
class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler):
    """Ship each formatted task log record to the message queue.

    Records are published through a CeleryLoggerProducer as ACTION_TASK_LOG
    events keyed by the task id supplied by the base class ``emit``.
    """

    def __init__(self):
        # Create the producer before the base initializer subscribes to the
        # Celery signals, so the attribute exists as early as possible.
        self.producer = CeleryLoggerProducer()
        super().__init__(stream=None)

    def write_task_log(self, task_id, record):
        """Format ``record`` and publish it as a log event for ``task_id``."""
        self.producer.log(task_id, self.format(record))

    def flush(self):
        """Delegate flushing to the producer."""
        self.producer.flush()
|
|
|
|
|
|
|
|
|
|
|
|
class CeleryTaskFileHandler(CeleryTaskLoggerHandler):
    """Write records to the log file of the task that is currently running.

    A single file handle (``self.f``) is tracked. ``handle_task_start`` opens
    ``get_celery_task_log_path(task_id)`` for appending, and
    ``handle_task_end`` closes it. While no file is open, records are
    dropped.
    """

    def __init__(self, *args, **kwargs):
        # Text file for the current task, or None when no task file is open.
        self.f = None
        super().__init__(*args, **kwargs)

    def emit(self, record):
        """Append the formatted record plus ``terminator`` to the open task file."""
        f = self.f
        if f is None or f.closed:
            return
        try:
            f.write(self.format(record))
            f.write(self.terminator)
            self.flush()
        except Exception:
            # Logging handlers must report failures via handleError instead
            # of raising into the code that logged (matches the base emit).
            self.handleError(record)

    def flush(self):
        """Flush the task file if one is open; a closed file is skipped."""
        f = self.f
        if f is not None and not f.closed:
            f.flush()

    def handle_task_start(self, task_id):
        """Open (append mode) the log file for ``task_id``."""
        # Close any file still open from a previous task instead of just
        # rebinding ``self.f`` and leaving the old handle to the GC.
        self._close_file()
        log_path = get_celery_task_log_path(task_id)
        # Explicit UTF-8 instead of the locale default, consistent with the
        # UTF-8 bytes the threaded handler writes. newline='' keeps the
        # '\r\n' terminator from being translated on platforms that rewrite '\n'.
        self.f = open(log_path, 'a', encoding='utf-8', newline='')

    def handle_task_end(self, task_id):
        """Close the current task file, if any."""
        self._close_file()

    def _close_file(self):
        """Close and forget ``self.f`` (closing an already-closed file is a no-op)."""
        f, self.f = self.f, None
        if f is not None:
            f.close()
|
|
|
|
|
|
|
|
|
|
|
|
class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
    """Per-thread task log files, for workers that run tasks in threads.

    ``handle_task_start`` opens ``get_celery_task_log_path(task_id)`` in binary
    append mode and associates it with the calling thread. Records emitted
    from that thread are appended to it. ``handle_task_end`` writes
    ``CELERY_LOG_MAGIC_MARK`` and closes the file.

    The two mappings are shared by all threads. They are mutated and iterated
    only while holding the handler lock. ``write_thread_task_log`` relies on
    ``logging.Handler.handle`` holding that same lock around ``emit``.
    """

    def __init__(self, *args, **kwargs):
        # thread id (str) -> binary file opened for the task running on that thread
        self.thread_id_fd_mapper = {}
        # task id -> thread id that started the task
        self.task_id_thread_id_mapper = {}
        super().__init__(*args, **kwargs)

    def write_thread_task_log(self, thread_id, record):
        """Append the UTF-8 encoded record plus terminator to the thread's task file.

        Raises:
            ValueError: no task file is open for ``thread_id``. The caller's
                ``emit`` routes this to ``handleError``.
        """
        f = self.thread_id_fd_mapper.get(thread_id, None)
        if f is None:
            raise ValueError('Not found thread task file')
        msg = self.format(record)
        f.write(msg.encode())
        f.write(self.terminator.encode())
        f.flush()

    def flush(self):
        """Flush every open task file.

        Iteration happens under the handler lock. Without it,
        ``handle_task_start``/``handle_task_end`` running on other threads
        could resize the dict mid-iteration (RuntimeError) or close a file
        about to be flushed.
        """
        self.acquire()
        try:
            for f in self.thread_id_fd_mapper.values():
                if not f.closed:
                    f.flush()
        finally:
            self.release()

    def handle_task_start(self, task_id):
        """Open ``task_id``'s log file and bind it to the calling thread."""
        log_path = get_celery_task_log_path(task_id)
        thread_id = self.get_current_thread_id()
        f = open(log_path, 'ab')
        # Publish both mappings together, under the lock, once the file is open.
        self.acquire()
        try:
            self.task_id_thread_id_mapper[task_id] = thread_id
            self.thread_id_fd_mapper[thread_id] = f
        finally:
            self.release()

    def handle_task_end(self, task_id):
        """Write the end-of-log marker to ``task_id``'s file and close it."""
        self.acquire()
        try:
            thread_id = self.task_id_thread_id_mapper.pop(task_id, '')
            f = self.thread_id_fd_mapper.pop(thread_id, None)
        finally:
            self.release()
        # Once popped, the file is unreachable from other threads, so the
        # final write and close can happen outside the lock.
        if f is None or f.closed:
            return
        try:
            # NOTE(review): the file is binary, so CELERY_LOG_MAGIC_MARK is
            # presumably bytes - confirm in ..const.
            f.write(CELERY_LOG_MAGIC_MARK)
        finally:
            # Close even if the marker write fails, so the fd is not leaked.
            f.close()
|