You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
jumpserver/apps/ops/api/celery.py

170 lines
5.5 KiB

# -*- coding: utf-8 -*-
#
import os
import re
from celery.result import AsyncResult
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _
from django_celery_beat.models import PeriodicTask
from rest_framework import generics, viewsets, mixins, status
from rest_framework.response import Response
from django_filters import rest_framework as drf_filters
from common.api import LogTailApi, CommonApiMixin
from common.exceptions import JMSException
from common.permissions import IsValidUser
from common.drf.filters import BaseFilterSet
from ops.celery import app
from ..ansible.utils import get_ansible_task_log_path
from ..celery.utils import get_celery_task_log_path
from ..models import CeleryTaskExecution, CeleryTask
from ..serializers import CeleryResultSerializer, CeleryPeriodTaskSerializer
from ..serializers.celery import CeleryTaskSerializer, CeleryTaskExecutionSerializer
# Names exported by ``from <this module> import *``.
# NOTE(review): CelerySummaryAPIView and CeleryTaskFilterSet are defined in
# this module but are not listed here — confirm that is intentional.
__all__ = [
'CeleryTaskExecutionLogApi', 'CeleryResultApi', 'CeleryPeriodTaskViewSet',
'AnsibleTaskLogApi', 'CeleryTaskViewSet', 'CeleryTaskExecutionViewSet'
]
class CeleryTaskExecutionLogApi(LogTailApi):
    """Log-tail endpoint for one Celery task execution.

    The task is addressed by the ``pk`` URL kwarg. The hooks below plug
    into ``LogTailApi``: where the log file lives, whether it is still
    being written, and what to show while no file exists yet.
    """
    permission_classes = (IsValidUser,)
    # Bound per request in ``get``.
    task = None
    task_id = ''
    # Celery's "Task ... succeeded in N.NNs" bookkeeping text, which is
    # stripped from the log output by ``filter_line``.
    pattern = re.compile(r'Task .* succeeded in \d+\.\d+s.*')

    def get(self, request, *args, **kwargs):
        """Bind the task id and its ``AsyncResult``, then delegate to the base view."""
        self.task_id = str(kwargs.get('pk'))
        self.task = AsyncResult(self.task_id)
        return super().get(request, *args, **kwargs)

    def filter_line(self, line):
        """Return ``line`` with Celery's "succeeded in" text removed.

        Text that does not start with the pattern is returned unchanged.
        """
        if self.pattern.match(line):
            # Pattern.sub takes (repl, string). The arguments used to be
            # swapped, which used the log text as the replacement template
            # against an empty string and always discarded the whole input.
            line = self.pattern.sub('', line)
        return line

    def get_log_path(self):
        """Return the log file path for the task, or ``None`` if unknown.

        The path from ``get_celery_task_log_path`` is preferred when that
        file exists; otherwise the ``full_log_path`` of the matching
        ``CeleryTaskExecution`` record is used.
        """
        new_path = get_celery_task_log_path(self.task_id)
        if new_path and os.path.isfile(new_path):
            return new_path
        try:
            task = CeleryTaskExecution.objects.get(id=self.task_id)
        except CeleryTaskExecution.DoesNotExist:
            return None
        return task.full_log_path

    def is_file_finish_write(self):
        """Report the log as complete once the Celery result is ready."""
        return self.task.ready()

    def get_no_file_message(self, request):
        """Return the placeholder text shown while no log file exists.

        ``self.mark`` is not set in this class; presumably it is provided
        by ``LogTailApi`` — verify against the base class.
        """
        if self.mark == 'undefined':
            return '.'
        return _('Waiting task start')
class AnsibleTaskLogApi(LogTailApi):
    """Log-tail endpoint for an Ansible task addressed by the ``pk`` URL kwarg."""
    permission_classes = (IsValidUser,)

    def get_log_path(self):
        """Return the Ansible task log path if that file exists, else ``None``."""
        task_pk = self.kwargs.get('pk')
        log_path = get_ansible_task_log_path(task_pk)
        if not log_path or not os.path.isfile(log_path):
            return None
        return log_path

    def get_no_file_message(self, request):
        """Return the placeholder text shown while no log file exists."""
        if self.mark != 'undefined':
            return _('Waiting task start')
        return '.'
class CeleryResultApi(generics.RetrieveAPIView):
    """Retrieve the Celery result handle for the task id given as ``pk``."""
    permission_classes = (IsValidUser,)
    serializer_class = CeleryResultSerializer

    def get_object(self):
        """Build an ``AsyncResult`` from the stringified ``pk`` URL kwarg."""
        task_id = str(self.kwargs.get('pk'))
        return AsyncResult(task_id)
class CeleryPeriodTaskViewSet(CommonApiMixin, viewsets.ModelViewSet):
    """Read and patch periodic tasks; other HTTP methods are not allowed."""
    queryset = PeriodicTask.objects.all()
    serializer_class = CeleryPeriodTaskSerializer
    http_method_names = ('get', 'head', 'options', 'patch')

    def get_queryset(self):
        """Return the base queryset without tasks whose description is empty."""
        return super().get_queryset().exclude(description='')
class CelerySummaryAPIView(generics.RetrieveAPIView):
    """Placeholder summary endpoint; ``get`` has no implementation yet."""

    def get(self, request, *args, **kwargs):
        # NOTE(review): the handler returns None rather than a Response;
        # this looks like an unfinished stub — confirm before routing to it.
        return None
class CeleryTaskFilterSet(BaseFilterSet):
    """Filters for ``CeleryTask``: exact ``id`` and a ``name`` substring match."""
    id = drf_filters.CharFilter(field_name='id', lookup_expr='exact')
    name = drf_filters.CharFilter(method='filter_name')

    @staticmethod
    def filter_name(queryset, name, value):
        """Keep tasks whose ``meta['comment']`` contains ``value``.

        The comment is read from ``task.meta`` on each object, so matching
        is done in Python and the result is narrowed by the collected ids.
        """
        matching_ids = []
        for celery_task in queryset:
            remark = celery_task.meta.get('comment')
            # Tasks with no comment can never match.
            if remark and value in remark:
                matching_ids.append(celery_task.id)
        return queryset.filter(id__in=matching_ids)

    class Meta:
        model = CeleryTask
        fields = ['id', 'name']
class CeleryTaskViewSet(
    CommonApiMixin,
    mixins.RetrieveModelMixin,
    mixins.ListModelMixin,
    mixins.DestroyModelMixin,
    viewsets.GenericViewSet,
):
    """Retrieve, list and delete ``CeleryTask`` records."""
    search_fields = ('id', 'name')
    filterset_class = CeleryTaskFilterSet
    serializer_class = CeleryTaskSerializer

    def get_queryset(self):
        """Return all tasks except those whose name starts with ``celery``."""
        visible_tasks = CeleryTask.objects.exclude(name__startswith='celery')
        return visible_tasks
class CeleryTaskExecutionViewSet(CommonApiMixin, viewsets.ModelViewSet):
    """List task executions and re-submit a previous execution.

    ``GET`` supports an optional ``task_id`` query parameter that narrows
    the list to executions of that task. ``POST`` re-runs an existing
    execution identified by the ``from`` query parameter.
    """
    serializer_class = CeleryTaskExecutionSerializer
    http_method_names = ('get', 'post', 'head', 'options',)
    queryset = CeleryTaskExecution.objects.all()
    search_fields = ('name',)

    def get_queryset(self):
        """Return executions, optionally limited to one task's name.

        Raises a 404 (via ``get_object_or_404``) when ``task_id`` is given
        but no such ``CeleryTask`` exists.
        """
        # Go through super() instead of touching ``self.queryset`` so every
        # request works on a fresh queryset, matching how
        # CeleryPeriodTaskViewSet builds its queryset.
        queryset = super().get_queryset()
        task_id = self.request.query_params.get('task_id')
        if task_id:
            task = get_object_or_404(CeleryTask, id=task_id)
            queryset = queryset.filter(name=task.name)
        return queryset

    def create(self, request, *args, **kwargs):
        """Re-submit the execution named by the ``from`` query parameter.

        Returns 400 when ``from`` is missing, 404 when the execution does
        not exist, and 201 with the new ``task_id`` on success.

        Raises:
            JMSException: the task is not registered with the Celery app,
                or the stored args/kwargs do not fit the task signature.
        """
        from_id = self.request.query_params.get('from', None)
        if not from_id:
            return Response(status=status.HTTP_400_BAD_REQUEST)

        execution = get_object_or_404(CeleryTaskExecution, id=from_id)
        task = app.tasks.get(execution.name, None)
        if not task:
            msg = _("Task {} not found").format(execution.name)
            raise JMSException(code='task_not_found_error', detail=msg)

        try:
            t = task.delay(*execution.args, **execution.kwargs)
        except TypeError as err:
            msg = _("Task {} args or kwargs error").format(execution.name)
            # Keep the original TypeError attached for debugging.
            raise JMSException(code='task_args_error', detail=msg) from err
        return Response(status=status.HTTP_201_CREATED, data={'task_id': t.id})