You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

124 lines
4.5 KiB

import asyncio
import os
import aiofiles
from asgiref.sync import sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from common.db.utils import close_old_connections
from common.utils import get_logger
from rbac.builtin import BuiltinRole
from .ansible.utils import get_ansible_task_log_path
from .celery.utils import get_celery_task_log_path
from .const import CELERY_LOG_MAGIC_MARK
from .models import CeleryTaskExecution
# Module-level logger, created through the project's `get_logger` helper.
logger = get_logger(__name__)
class TaskLogWebsocket(AsyncJsonWebsocketConsumer):
    """Websocket consumer that streams a task's log file to the client.

    The client sends a JSON message with a ``task`` id and an optional
    ``type`` (one of the keys of ``log_types``, default ``'celery'``). After a
    permission check the consumer waits for the log file to appear and then
    forwards its content in chunks until the end-of-log magic mark is seen or
    the client disconnects.
    """

    # Set to True by ``disconnect`` so that the polling/streaming loops stop.
    disconnected = False
    # Task names whose logs are restricted to the task creator unless the
    # requesting user holds an admin/auditor role (see ``receive_json``).
    user_tasks = (
        'ops.tasks.run_ops_job',
        'ops.tasks.run_ops_job_execution',
    )
    # Maps the client-supplied log type to the function resolving the log path.
    log_types = {
        'celery': get_celery_task_log_path,
        'ansible': get_ansible_task_log_path
    }

    async def connect(self):
        """Accept the connection for authenticated users, close it otherwise."""
        user = self.scope["user"]
        if user.is_authenticated:
            await self.accept()
        else:
            await self.close()

    def get_log_path(self, task_id, log_type):
        """Return the log path for ``task_id``, or None for an unknown log type.

        ``log_type`` comes straight from the client, so anything that is not a
        string (e.g. a JSON list, which would be unhashable) is treated as
        unknown instead of raising ``TypeError`` in the dict lookup.
        """
        if not isinstance(log_type, str):
            return None
        func = self.log_types.get(log_type)
        if func:
            return func(task_id)
        return None

    @sync_to_async
    def get_task(self, task_id):
        """Return the ``CeleryTaskExecution`` with this id, or None if absent."""
        task = CeleryTaskExecution.objects.filter(id=task_id).first()
        # task.creator is a foreign key that would otherwise be queried lazily
        # from async code (the `task.creator` check in receive_json), which
        # raises an error; touching it here loads it inside the sync context.
        if task and task.creator != ' ':
            return task
        else:
            return None

    @sync_to_async
    def get_current_user_role_ids(self, user):
        """Return the ids of the user's system and org roles as a set of str."""
        roles = user.system_roles.all() | user.org_roles.all()
        user_role_ids = set(map(str, roles.values_list('id', flat=True)))
        return user_role_ids

    async def receive_json(self, content, **kwargs):
        """Handle a log request: validate task, permission and type, then stream.

        Replies with a ``message`` explaining the failure when the task does
        not exist, the user may not read it, or the log type is not supported.
        """
        task_id = content.get('task')
        task = await self.get_task(task_id)
        if not task:
            await self.send_json({'message': 'Task not found', 'task': task_id})
            return

        # NOTE(review): assumes BuiltinRole ids are strings, matching the str()
        # conversion done in get_current_user_role_ids -- TODO confirm.
        admin_auditor_role_ids = {
            BuiltinRole.system_admin.id,
            BuiltinRole.system_auditor.id,
            BuiltinRole.org_admin.id,
            BuiltinRole.org_auditor.id
        }
        user = self.scope['user']
        user_role_ids = await self.get_current_user_role_ids(user)
        has_admin_auditor_role = bool(admin_auditor_role_ids & user_role_ids)
        # Users without an admin/auditor role may only read the logs of the
        # restricted tasks that they created themselves.
        if not has_admin_auditor_role and task.name in self.user_tasks and task.creator != user:
            await self.send_json({'message': 'No permission', 'task': task_id})
            return

        task_type = content.get('type', 'celery')
        log_path = self.get_log_path(task_id, task_type)
        if not log_path:
            # Without this guard a None path reaches os.path.exists() and
            # raises TypeError inside async_handle_task.
            await self.send_json({'message': 'Unsupported log type', 'task': task_id})
            return
        await self.async_handle_task(task_id, log_path)

    async def async_handle_task(self, task_id, log_path):
        """Wait up to 60 seconds for the log file to exist, then stream it.

        While waiting, a ``'.'`` message is sent every 0.5 seconds; if the file
        never appears a "not found" message is sent instead.
        """
        logger.info("Task id: {}".format(task_id))
        timeout = 0  # seconds spent waiting for the log file so far
        while not self.disconnected:
            if timeout >= 60:
                await self.send_json({'message': '\r\n', 'task': task_id})
                await self.send_json({'message': 'Task log was not found, the directory may not be shared.',
                                      'task': task_id})
                break
            if not os.path.exists(log_path):
                await self.send_json({'message': '.', 'task': task_id})
                timeout += 0.5
                await asyncio.sleep(0.5)
            else:
                await self.send_task_log(task_id, log_path)
                break

    async def send_task_log(self, task_id, log_path):
        """Stream the log file to the client until the magic mark or disconnect.

        The file is read in 4096-byte chunks; ``\\n`` is rewritten to ``\\r\\n``
        before sending. When ``CELERY_LOG_MAGIC_MARK`` is found an ``end``
        event is sent and streaming stops. ``OSError`` raised while opening or
        reading the file is logged, not propagated.
        """
        import codecs  # stdlib; imported locally so the block is self-contained

        await self.send_json({'message': '\r\n'})
        # An incremental decoder buffers a multi-byte character that is split
        # across two reads; decoding each chunk on its own with
        # errors='ignore' would silently drop such characters.
        decoder = codecs.getincrementaldecoder('utf-8')(errors='ignore')
        # Number of trailing bytes to carry over so that a magic mark
        # straddling two chunks is still detected.
        mark_overlap = len(CELERY_LOG_MAGIC_MARK) - 1
        tail = b''
        try:
            logger.debug('Task log path: {}'.format(log_path))
            async with aiofiles.open(log_path, 'rb') as task_log_f:
                while not self.disconnected:
                    data = await task_log_f.read(4096)
                    if data:
                        data = data.replace(b'\n', b'\r\n')
                        await self.send_json(
                            {'message': decoder.decode(data), 'task': task_id}
                        )
                        window = tail + data
                        if window.find(CELERY_LOG_MAGIC_MARK) != -1:
                            await self.send_json(
                                {'event': 'end', 'task': task_id, 'message': ''}
                            )
                            logger.debug("Task log file magic mark found")
                            break
                        tail = window[-mark_overlap:] if mark_overlap > 0 else b''
                    # Pause before polling the file again for new content.
                    await asyncio.sleep(0.2)
        except OSError as e:
            # logger.warn is a deprecated alias (removed in Python 3.13).
            logger.warning('Task log path open failed: {}'.format(e))

    async def disconnect(self, close_code):
        """Flag the consumer as disconnected so that running loops terminate."""
        self.disconnected = True
        close_old_connections()