mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			116 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			116 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
| from django.shortcuts import get_object_or_404
 | |
| from django.utils.translation import ugettext_lazy as _
 | |
| from rest_framework import status, mixins, viewsets
 | |
| from rest_framework.response import Response
 | |
| 
 | |
| from assets import serializers
 | |
| from assets.models import BaseAutomation
 | |
| from accounts.tasks import execute_automation
 | |
| from accounts.models import AutomationExecution
 | |
| from common.const.choices import Trigger
 | |
| from orgs.mixins import generics
 | |
| 
 | |
# Public names exported by this module: the automation asset/node
# management views and the automation execution viewset.
__all__ = [
    'AutomationAssetsListApi',
    'AutomationRemoveAssetApi',
    'AutomationAddAssetApi',
    'AutomationNodeAddRemoveApi',
    'AutomationExecutionViewSet',
]
 | |
| 
 | |
| 
 | |
class AutomationAssetsListApi(generics.ListAPIView):
    """List the assets covered by one automation.

    The automation is looked up by the ``pk`` URL kwarg; the response lists
    the assets returned by its ``get_all_assets()``, with the queryset
    restricted to the columns named in the serializer's ``Meta.only_fields``.
    """
    model = BaseAutomation
    serializer_class = serializers.AutomationAssetsSerializer
    # ``filterset_fields`` is the attribute django-filter (>= 2.0) reads;
    # the legacy ``filter_fields`` name is ignored by current releases, which
    # left name/address filtering silently disabled. The old name is kept as
    # an alias so any code still reading it keeps working.
    filterset_fields = filter_fields = ("name", "address")
    search_fields = filterset_fields

    def get_object(self):
        """Return the automation identified by the ``pk`` URL kwarg.

        Raises Http404 when no matching automation exists.
        """
        pk = self.kwargs.get('pk')
        # NOTE(review): unlike GenericAPIView.get_object, this override does
        # not call self.check_object_permissions(); confirm that is intended.
        return get_object_or_404(self.model, pk=pk)

    def get_queryset(self):
        """Return the automation's assets, loading only serializer-needed columns."""
        instance = self.get_object()
        only_fields = self.serializer_class.Meta.only_fields
        return instance.get_all_assets().only(*only_fields)
 | |
| 
 | |
| 
 | |
class AutomationRemoveAssetApi(generics.RetrieveUpdateAPIView):
    """Detach assets from an automation.

    The request body is validated with ``serializers.UpdateAssetSerializer``;
    any ``assets`` it yields are removed from the automation resolved by
    ``get_object()``.
    """
    model = BaseAutomation
    serializer_class = serializers.UpdateAssetSerializer

    def update(self, request, *args, **kwargs):
        """Remove the submitted assets and answer ``{'msg': 'ok'}``.

        An invalid payload yields ``{'error': <serializer errors>}`` with
        HTTP 400 (previously it was sent with an implicit 200, which clients
        could not distinguish from success).
        """
        instance = self.get_object()
        serializer = self.serializer_class(data=request.data)

        if not serializer.is_valid():
            return Response(
                {'error': serializer.errors},
                status=status.HTTP_400_BAD_REQUEST,
            )

        assets = serializer.validated_data.get('assets')
        if assets:
            instance.assets.remove(*assets)
        return Response({'msg': 'ok'})
 | |
| 
 | |
| 
 | |
class AutomationAddAssetApi(generics.RetrieveUpdateAPIView):
    """Attach assets to an automation.

    The request body is validated with ``serializers.UpdateAssetSerializer``;
    any ``assets`` it yields are added to the automation resolved by
    ``get_object()``.
    """
    model = BaseAutomation
    serializer_class = serializers.UpdateAssetSerializer

    def update(self, request, *args, **kwargs):
        """Add the submitted assets and answer ``{"msg": "ok"}``.

        An invalid payload yields ``{"error": <serializer errors>}`` with
        HTTP 400 (previously it was sent with an implicit 200, which clients
        could not distinguish from success).
        """
        instance = self.get_object()
        serializer = self.serializer_class(data=request.data)

        if not serializer.is_valid():
            return Response(
                {"error": serializer.errors},
                status=status.HTTP_400_BAD_REQUEST,
            )

        assets = serializer.validated_data.get('assets')
        if assets:
            instance.assets.add(*assets)
        return Response({"msg": "ok"})
 | |
| 
 | |
| 
 | |
class AutomationNodeAddRemoveApi(generics.RetrieveUpdateAPIView):
    """Add nodes to, or remove nodes from, an automation.

    The ``action`` query parameter selects the operation (``add`` or
    ``remove``); the request body is validated with
    ``serializers.UpdateNodeSerializer``.
    """
    model = BaseAutomation
    serializer_class = serializers.UpdateNodeSerializer

    def update(self, request, *args, **kwargs):
        """Apply ``?action=add|remove`` to the automation's nodes.

        Returns ``{"msg": "ok"}`` on success. An unknown ``action`` or an
        invalid payload yields ``{"error": ...}`` with HTTP 400.
        """
        action_params = ['add', 'remove']
        action = request.query_params.get('action')
        if action not in action_params:
            # Translate the template, THEN format it: gettext looks messages up
            # by their literal msgid, so formatting inside _() produced a key
            # that never matched the catalogue and the text was never localized.
            err_info = _("The parameter 'action' must be [{}]").format(
                ','.join(action_params)
            )
            return Response(
                {"error": err_info}, status=status.HTTP_400_BAD_REQUEST
            )

        instance = self.get_object()
        serializer = self.serializer_class(data=request.data)
        if not serializer.is_valid():
            return Response(
                {"error": serializer.errors},
                status=status.HTTP_400_BAD_REQUEST,
            )

        nodes = serializer.validated_data.get('nodes')
        if nodes:
            # ``action`` was checked above, so this is instance.nodes.add(...)
            # or instance.nodes.remove(...).
            getattr(instance.nodes, action)(*nodes)
        return Response({"msg": "ok"})
 | |
| 
 | |
| 
 | |
class AutomationExecutionViewSet(
    mixins.CreateModelMixin,
    mixins.ListModelMixin,
    mixins.RetrieveModelMixin,
    viewsets.GenericViewSet,
):
    """Create, list and retrieve automation executions.

    ``tp`` is only annotated here, never assigned; presumably concrete
    subclasses set it. Without it, ``create`` raises AttributeError.
    """
    search_fields = ('trigger',)
    filterset_fields = ('trigger', 'automation_id')
    serializer_class = serializers.AutomationExecutionSerializer

    tp: str

    def get_queryset(self):
        """Return every AutomationExecution row (no extra filtering here)."""
        return AutomationExecution.objects.all()

    def create(self, request, *args, **kwargs):
        """Validate the payload and enqueue ``execute_automation``.

        The task is queued via ``.delay`` with the chosen automation's pk, a
        manual trigger and this viewset's ``tp``. The queued task's id is
        returned with HTTP 201.
        """
        payload_serializer = self.get_serializer(data=request.data)
        payload_serializer.is_valid(raise_exception=True)
        target = payload_serializer.validated_data.get('automation')

        queued = execute_automation.delay(
            pid=target.pk,
            trigger=Trigger.manual,
            tp=self.tp,
        )
        body = {'task': queued.id}
        return Response(body, status=status.HTTP_201_CREATED)
 |