mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			279 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			279 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Python
		
	
	
# -*- coding: utf-8 -*-
 | 
						|
#
 | 
						|
import os
 | 
						|
import re
 | 
						|
from collections import defaultdict
 | 
						|
 | 
						|
from celery.result import AsyncResult
 | 
						|
from django.shortcuts import get_object_or_404
 | 
						|
from django.utils.translation import gettext as _
 | 
						|
from django_celery_beat.models import PeriodicTask
 | 
						|
from django_filters import rest_framework as drf_filters
 | 
						|
from rest_framework import generics, viewsets, mixins, status
 | 
						|
from rest_framework.response import Response
 | 
						|
 | 
						|
from common.api import LogTailApi, CommonApiMixin
 | 
						|
from common.drf.filters import BaseFilterSet
 | 
						|
from common.exceptions import JMSException
 | 
						|
from common.permissions import IsValidUser
 | 
						|
from common.utils.timezone import local_now
 | 
						|
from ops.celery import app
 | 
						|
from ..ansible.utils import get_ansible_task_log_path
 | 
						|
from ..celery.utils import get_celery_task_log_path
 | 
						|
from ..models import CeleryTaskExecution, CeleryTask
 | 
						|
from ..serializers import CeleryResultSerializer, CeleryPeriodTaskSerializer
 | 
						|
from ..serializers.celery import CeleryTaskSerializer, CeleryTaskExecutionSerializer
 | 
						|
 | 
						|
# Public names exported by a star import of this module.
__all__ = [
    'CeleryTaskExecutionLogApi',
    'CeleryResultApi',
    'CeleryPeriodTaskViewSet',
    'AnsibleTaskLogApi',
    'CeleryTaskViewSet',
    'CeleryTaskExecutionViewSet',
]
 | 
						|
 | 
						|
 | 
						|
class CeleryTaskExecutionLogApi(LogTailApi):
    """Tail the log of a single Celery task execution.

    The task id comes from the ``pk`` URL kwarg. The log counts as finished
    once the Celery ``AsyncResult`` for that id reports ``ready()``.
    """
    permission_classes = (IsValidUser,)
    task = None
    task_id = ''
    # Celery's "Task <name> succeeded in <secs>s..." line; filter_line hides it.
    pattern = re.compile(r'Task .* succeeded in \d+\.\d+s.*')

    def get(self, request, *args, **kwargs):
        self.task_id = str(kwargs.get('pk'))
        self.task = AsyncResult(self.task_id)
        return super().get(request, *args, **kwargs)

    def filter_line(self, line):
        """Return ``line``, or ``''`` when it starts with the "succeeded" message."""
        # The previous code called ``self.pattern.sub(line, '')`` with the
        # arguments swapped: the log line was used as the replacement template
        # (so a matching line containing a backslash raised ``re.error``) and
        # the empty string was searched, which always yielded ''. Returning ''
        # directly keeps that visible result without the crash.
        if self.pattern.match(line):
            return ''
        return line

    def get_log_path(self):
        """Return the log file path for ``self.task_id``.

        Prefers the Celery log path when that file exists; otherwise falls back
        to ``full_log_path`` of the matching CeleryTaskExecution (not checked
        for existence). Returns None when there is no such execution.
        """
        new_path = get_celery_task_log_path(self.task_id)
        if new_path and os.path.isfile(new_path):
            return new_path
        try:
            task = CeleryTaskExecution.objects.get(id=self.task_id)
        except CeleryTaskExecution.DoesNotExist:
            return None
        return task.full_log_path

    def is_file_finish_write(self):
        # The log is complete once the Celery result is ready.
        return self.task.ready()

    def get_no_file_message(self, request):
        # ``self.mark`` is not set in this class; presumably LogTailApi
        # provides it — verify there what 'undefined' signifies.
        if self.mark == 'undefined':
            return '.'
        else:
            return _('Waiting task start')
 | 
						|
 | 
						|
 | 
						|
class AnsibleTaskLogApi(LogTailApi):
    """Tail the log file of an Ansible task identified by the ``pk`` URL kwarg."""
    permission_classes = (IsValidUser,)

    def get_log_path(self):
        """Return the Ansible log path for this task if the file exists, else None."""
        candidate = get_ansible_task_log_path(self.kwargs.get('pk'))
        if not candidate or not os.path.isfile(candidate):
            return None
        return candidate

    def get_no_file_message(self, request):
        # ``self.mark`` is not set in this class; presumably LogTailApi
        # provides it — verify there what 'undefined' signifies.
        return '.' if self.mark == 'undefined' else _('Waiting task start')
 | 
						|
 | 
						|
 | 
						|
class CeleryResultApi(generics.RetrieveAPIView):
    """Retrieve the Celery ``AsyncResult`` for the task id given as ``pk``."""
    permission_classes = (IsValidUser,)
    serializer_class = CeleryResultSerializer

    def get_object(self):
        # No database lookup: the "object" is the AsyncResult for this id.
        result_id = str(self.kwargs.get('pk'))
        return AsyncResult(result_id)
 | 
						|
 | 
						|
 | 
						|
class CeleryPeriodTaskViewSet(CommonApiMixin, viewsets.ModelViewSet):
    """List, retrieve and partially update django-celery-beat PeriodicTasks by name."""
    queryset = PeriodicTask.objects.all()
    serializer_class = CeleryPeriodTaskSerializer
    # Read plus PATCH only; no create, full update or delete.
    http_method_names = ('get', 'head', 'options', 'patch')
    lookup_field = 'name'
    # Raw string: '\w' in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12). Same value.
    lookup_value_regex = r'[\w.@]+'

    def get_object(self):
        """Return the PeriodicTask whose ``name`` matches the URL kwarg, or raise 404."""
        # NOTE(review): unlike DRF's default get_object, this override does not
        # call self.check_object_permissions(); confirm that is intended.
        name = self.kwargs.get('name')
        obj = get_object_or_404(PeriodicTask, name=name)
        return obj
 | 
						|
 | 
						|
 | 
						|
class CelerySummaryAPIView(generics.RetrieveAPIView):
    """Placeholder for a Celery summary endpoint; not implemented yet."""

    def get(self, request, *args, **kwargs):
        # The previous body was ``pass``. A DRF handler returning None fails
        # DRF's check that a Response is returned, surfacing as a 500.
        # Answer with an explicit 501 until the endpoint is implemented.
        return Response(status=status.HTTP_501_NOT_IMPLEMENTED)
 | 
						|
 | 
						|
 | 
						|
class CeleryTaskFilterSet(BaseFilterSet):
    """Filter CeleryTask rows; ``name`` is matched against the task's meta comment."""
    name = drf_filters.CharFilter(method='filter_name')

    @staticmethod
    def filter_name(queryset, name, value):
        """Keep tasks whose ``meta['comment']`` contains ``value`` (substring test).

        The matching runs in Python over every row of ``queryset``, and the
        hits are mapped back to a queryset with ``id__in``. Tasks with no
        comment are excluded. ``name`` (the filter field name) is unused.
        """
        def comment_contains(task):
            comment = task.meta.get('comment')
            return bool(comment) and value in comment

        matched_ids = [task.id for task in queryset if comment_contains(task)]
        return queryset.filter(id__in=matched_ids)

    class Meta:
        model = CeleryTask
        fields = ['name']
 | 
						|
 | 
						|
 | 
						|
class CeleryTaskViewSet(
    CommonApiMixin, mixins.RetrieveModelMixin,
    mixins.ListModelMixin, mixins.DestroyModelMixin,
    viewsets.GenericViewSet
):
    """List, retrieve and delete registered Celery tasks (CeleryTask rows).

    Before serializing, the list view sets extra attributes on each task:
    periodic-schedule info (``is_periodic``, and for periodic ones
    ``exec_cycle``, ``next_exec_time``, ``enabled``) and an execution summary
    (``summary``, ``state``).
    """
    search_fields = ('name',)
    filterset_class = CeleryTaskFilterSet
    serializer_class = CeleryTaskSerializer

    def get_queryset(self):
        # Exclude tasks whose name starts with 'celery' (presumably Celery's
        # own built-in tasks).
        return CeleryTask.objects.exclude(name__startswith='celery')

    @staticmethod
    def extract_schedule(input_string):
        """Return the first five single-space-separated fields of ``input_string``.

        Applied to ``str(periodic_task.scheduler)``; for a crontab schedule
        these are presumably the five cron fields. The string is returned
        unchanged when it does not start with five such fields.
        """
        pattern = r'(\S+ \S+ \S+ \S+ \S+).*'
        match = re.match(pattern, input_string)
        if match:
            return match.group(1)
        else:
            return input_string

    def generate_execute_time(self, queryset):
        """Set ``exec_cycle``, ``next_exec_time`` and ``enabled`` on periodic tasks.

        Only items carrying a ``periodic_obj`` (set by mark_periodic_and_sorted)
        are touched. Returns ``queryset``.
        """
        now = local_now()
        for i in queryset:
            task = getattr(i, 'periodic_obj', None)
            if not task:
                continue
            i.exec_cycle = self.extract_schedule(str(task.scheduler))
            last_run_at = task.last_run_at or now
            next_run_at = task.schedule.remaining_estimate(last_run_at)
            # If the estimate relative to the last run is already negative
            # (that run time has passed), re-estimate relative to now.
            if next_run_at.total_seconds() < 0:
                next_run_at = task.schedule.remaining_estimate(now)
            i.next_exec_time = now + next_run_at
            i.enabled = task.enabled
        return queryset

    def generate_summary_state(self, execution_qs):
        """Aggregate ``execution_qs`` per task name.

        Returns a mapping ``name -> {'summary': {'total', 'success'}, 'state'}``.
        ``total`` counts all executions of the name, and ``success`` counts
        those in state ``'SUCCESS'``. ``state`` is
        ``model.compute_state_color`` applied to the name's most recent
        execution states (at most five).
        """
        model = self.get_queryset().model
        recent_count = 5
        executions = execution_qs.order_by('-date_published').values('name', 'state')
        summary_state_dict = defaultdict(
            lambda: {'states': [], 'summary': {'total': 0, 'success': 0}}
        )
        for execution in executions:
            name = execution['name']
            state = execution['state']
            entry = summary_state_dict[name]

            summary = entry['summary']
            summary['total'] += 1
            summary['success'] += 1 if state == 'SUCCESS' else 0

            # Rows are newest first, so the first `recent_count` states kept
            # for a name are its most recent ones.
            if len(entry['states']) < recent_count:
                entry['states'].append(state)

        # Colour every task only after all rows are counted. The previous code
        # computed the colour only when a 6th execution arrived, so a task with
        # five or fewer executions stayed 'green' even if all of them failed.
        # NOTE(review): assumes compute_state_color accepts fewer than five
        # states — confirm in the CeleryTask model.
        for entry in summary_state_dict.values():
            entry['state'] = model.compute_state_color(entry.pop('states'))
        return summary_state_dict

    def loading_summary_state(self, queryset):
        """Set ``summary`` and ``state`` on each task from its executions.

        When ``queryset`` is a list (a page, or the sorted list built in
        ``list``), only executions of those task names are aggregated;
        otherwise all executions are. Tasks without executions get an empty
        summary and state 'green'.
        """
        if isinstance(queryset, list):
            names = [i.name for i in queryset]
            execution_qs = CeleryTaskExecution.objects.filter(name__in=names)
        else:
            execution_qs = CeleryTaskExecution.objects.all()
        summary_state_dict = self.generate_summary_state(execution_qs)
        for i in queryset:
            entry = summary_state_dict.get(i.name, {})
            i.summary = entry.get('summary', {})
            i.state = entry.get('state', 'green')
        return queryset

    def filter_queryset(self, queryset):
        # A ``search`` term is matched against task comments via
        # CeleryTaskFilterSet.filter_name *instead of* the standard filter
        # backends; without it the standard backends apply.
        search = self.request.query_params.get('search')
        if search:
            queryset = CeleryTaskFilterSet.filter_name(queryset, 'name', search)
        else:
            queryset = super().filter_queryset(queryset)
        return queryset

    def list(self, request, *args, **kwargs):
        """List tasks with periodic ones first, decorated with schedule and summary data."""
        queryset = self.filter_queryset(self.get_queryset())
        queryset = self.mark_periodic_and_sorted(queryset)

        page = self.paginate_queryset(queryset)
        if page is not None:
            page = self.generate_execute_time(page)
            page = self.loading_summary_state(page)
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        queryset = self.generate_execute_time(queryset)
        queryset = self.loading_summary_state(queryset)
        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)

    @staticmethod
    def mark_periodic_and_sorted(queryset):
        """Flag tasks that have a matching PeriodicTask and sort those first.

        Sets ``is_periodic`` on every task and ``periodic_obj`` on matched
        ones. Returns a list, not a queryset. The sort is stable, so the
        incoming order is kept within each group.
        """
        names = queryset.values_list('name', flat=True)
        # NOTE(review): PeriodicTasks are selected by ``name`` but keyed by
        # ``task``, so one matches only if both its name and its task path
        # appear among the CeleryTask names — confirm this is intended.
        periodic_tasks = PeriodicTask.objects.filter(name__in=names)
        periodic_task_dict = {task.task: task for task in periodic_tasks}
        for q in queryset:
            if q.name in periodic_task_dict:
                q.periodic_obj = periodic_task_dict[q.name]
                q.is_periodic = True
            else:
                q.is_periodic = False
        queryset = sorted(queryset, key=lambda x: x.is_periodic, reverse=True)
        return queryset
 | 
						|
 | 
						|
 | 
						|
class CeleryTaskExecutionViewSet(CommonApiMixin, viewsets.ModelViewSet):
    """List and retrieve Celery task executions, and re-run one via POST.

    Non-superusers only see executions they created. ``?task_id=<CeleryTask id>``
    narrows the results to executions of that task.
    """
    serializer_class = CeleryTaskExecutionSerializer
    http_method_names = ('get', 'post', 'head', 'options',)
    queryset = CeleryTaskExecution.objects.all()
    search_fields = ('id',)

    def get_queryset(self):
        """Return the executions visible to the requesting user.

        Raises Http404 (via get_object_or_404) if ``task_id`` names no task.
        """
        # Start from a fresh clone. The previous code assigned the result back
        # to self.queryset and, when no filter applied, returned the
        # class-level QuerySet object itself. Evaluating that object caches its
        # rows on the class, so later requests could see stale data. ``.all()``
        # mirrors what DRF's GenericAPIView.get_queryset does.
        queryset = self.queryset.all()
        task_id = self.request.query_params.get('task_id')
        if task_id:
            task = get_object_or_404(CeleryTask, id=task_id)
            queryset = queryset.filter(name=task.name)
        if not self.request.user.is_superuser:
            queryset = queryset.filter(creator=self.request.user)
        return queryset

    def create(self, request, *args, **kwargs):
        """Re-run the task of an existing execution with its stored arguments.

        The source execution id is read from the ``from`` query parameter.
        Returns 400 if it is missing, raises Http404 if no such execution
        exists, and returns 201 with ``{'task_id': <new celery task id>}`` on
        success.

        Raises:
            JMSException: ``task_not_found_error`` if the task is not
                registered with the Celery app; ``task_args_error`` if calling
                ``delay`` with the stored args/kwargs raises TypeError.
        """
        from_id = self.request.query_params.get('from', None)
        if not from_id:
            return Response(status=status.HTTP_400_BAD_REQUEST)
        execution = get_object_or_404(CeleryTaskExecution, id=from_id)
        task = app.tasks.get(execution.name, None)
        if not task:
            msg = _("Task {} not found").format(execution.name)
            raise JMSException(code='task_not_found_error', detail=msg)
        try:
            # Strip the '__current_lang' / '__current_org_id' entries stored
            # with the kwargs before re-sending; presumably they are context
            # injected elsewhere rather than task parameters — verify.
            execution.kwargs.pop('__current_lang', None)
            execution.kwargs.pop('__current_org_id', None)
            t = task.delay(*execution.args, **execution.kwargs)
        except TypeError:
            msg = _("Task {} args or kwargs error").format(execution.name)
            raise JMSException(code='task_args_error', detail=msg)
        return Response(status=status.HTTP_201_CREATED, data={'task_id': t.id})
 |