Mirror of https://github.com/jumpserver/jumpserver
119 lines · 4.2 KiB · Python
|
from django.shortcuts import get_object_or_404
|
||
|
from django.utils.translation import ugettext_lazy as _
|
||
|
from rest_framework.response import Response
|
||
|
from rest_framework import status, mixins, viewsets
|
||
|
|
||
|
from orgs.mixins import generics
|
||
|
from assets import serializers
|
||
|
from assets.const import AutomationTypes
|
||
|
from assets.tasks import execute_automation
|
||
|
from assets.models import BaseAutomation, AutomationExecution
|
||
|
from common.const.choices import Trigger
|
||
|
|
||
|
# Public API views exported by this module.
__all__ = [
    'AutomationAssetsListApi',
    'AutomationRemoveAssetApi',
    'AutomationAddAssetApi',
    'AutomationNodeAddRemoveApi',
    'AutomationExecutionViewSet',
]
|
||
|
|
||
|
|
||
|
class AutomationAssetsListApi(generics.ListAPIView):
    """List every asset targeted by one automation.

    The automation comes from the ``pk`` URL kwarg. Only the columns named in
    the serializer's ``Meta.only_fields`` are loaded from the database.
    """
    serializer_class = serializers.AutomationAssetsSerializer
    filter_fields = ("name", "address")
    search_fields = filter_fields

    def get_object(self):
        """Return the ``BaseAutomation`` matching the ``pk`` URL kwarg, or 404."""
        return get_object_or_404(BaseAutomation, pk=self.kwargs.get('pk'))

    def get_queryset(self):
        """Return the automation's assets restricted to the serialized fields."""
        automation = self.get_object()
        all_assets = automation.get_all_assets()
        return all_assets.only(*self.serializer_class.Meta.only_fields)
|
||
|
|
||
|
|
||
|
class AutomationRemoveAssetApi(generics.RetrieveUpdateAPIView):
    """Detach the submitted assets from an automation.

    The payload is validated by ``UpdateAssetSerializer``. When it carries a
    non-empty ``assets`` value, those assets are removed from the
    automation's ``assets`` relation.
    """
    model = BaseAutomation
    serializer_class = serializers.UpdateAssetSerializer

    def update(self, request, *args, **kwargs):
        """Remove ``assets`` from the automation.

        Returns ``{'msg': 'ok'}`` on success, or ``{'error': <errors>}`` when
        the payload fails validation.
        """
        instance = self.get_object()
        # get_serializer() passes request/view context to the serializer,
        # which instantiating serializer_class directly does not.
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            # NOTE(review): validation errors are reported with HTTP 200 (kept
            # for existing clients); 400 would be the conventional status.
            return Response({'error': serializer.errors})

        assets = serializer.validated_data.get('assets')
        if assets:
            instance.assets.remove(*assets)
        return Response({'msg': 'ok'})
|
||
|
|
||
|
|
||
|
class AutomationAddAssetApi(generics.RetrieveUpdateAPIView):
    """Attach the submitted assets to an automation.

    The payload is validated by ``UpdateAssetSerializer``. When it carries a
    non-empty ``assets`` value, those assets are added to the automation's
    ``assets`` relation.
    """
    model = BaseAutomation
    serializer_class = serializers.UpdateAssetSerializer

    def update(self, request, *args, **kwargs):
        """Add ``assets`` to the automation.

        Returns ``{"msg": "ok"}`` on success, or ``{"error": <errors>}`` when
        the payload fails validation.
        """
        instance = self.get_object()
        # get_serializer() passes request/view context to the serializer,
        # which instantiating serializer_class directly does not.
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            # NOTE(review): validation errors are reported with HTTP 200 (kept
            # for existing clients); 400 would be the conventional status.
            return Response({"error": serializer.errors})

        assets = serializer.validated_data.get('assets')
        if assets:
            instance.assets.add(*assets)
        return Response({"msg": "ok"})
|
||
|
|
||
|
|
||
|
class AutomationNodeAddRemoveApi(generics.RetrieveUpdateAPIView):
    """Add nodes to, or remove nodes from, an automation.

    The operation is selected by the ``action`` query parameter, which must
    be ``add`` or ``remove``. The nodes come from the ``nodes`` field of the
    payload, validated by ``UpdateAssetSerializer``.
    """
    model = BaseAutomation
    serializer_class = serializers.UpdateAssetSerializer

    def update(self, request, *args, **kwargs):
        """Apply ``action`` (add/remove) with the submitted ``nodes``.

        Returns ``{"msg": "ok"}`` on success. Returns ``{"error": ...}`` when
        ``action`` is invalid or the payload fails validation.
        """
        action_params = ['add', 'remove']
        action = request.query_params.get('action')
        if action not in action_params:
            # Translate the template first and format afterwards, so the
            # msgid matches the entry in the translation catalog.
            err_info = _("The parameter 'action' must be [{}]").format(','.join(action_params))
            return Response({"error": err_info})

        instance = self.get_object()
        # get_serializer() passes request/view context to the serializer.
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return Response({"error": serializer.errors})

        nodes = serializer.validated_data.get('nodes')
        if nodes:
            # e.g. instance.nodes.add(*nodes) / instance.nodes.remove(*nodes)
            getattr(instance.nodes, action)(*nodes)
        return Response({"msg": "ok"})
|
||
|
|
||
|
|
||
|
class AutomationExecutionViewSet(
    mixins.CreateModelMixin, mixins.ListModelMixin,
    mixins.RetrieveModelMixin, viewsets.GenericViewSet
):
    """List, retrieve, and manually trigger automation executions."""
    search_fields = ('trigger',)
    filterset_fields = ('trigger', 'automation_id')
    serializer_class = serializers.AutomationExecutionSerializer

    def get_queryset(self):
        """Return all ``AutomationExecution`` rows."""
        return AutomationExecution.objects.all()

    def filter_queryset(self, queryset):
        """Apply the configured filters, then order newest ``date_start`` first."""
        queryset = super().filter_queryset(queryset)
        return queryset.order_by('-date_start')

    def create(self, request, *args, **kwargs):
        """Queue a manual run of the submitted automation.

        Responds ``201`` with ``{'task': <task id>}`` from
        ``execute_automation.delay``. Invalid payloads raise DRF's
        ``ValidationError`` via ``is_valid(raise_exception=True)``.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        automation = serializer.validated_data.get('automation')
        tp = serializer.validated_data.get('type')
        model = AutomationTypes.get_model(tp)
        # Fixed typo: the automation's primary key is ``pk``, not ``ok``.
        task = execute_automation.delay(
            pid=automation.pk, trigger=Trigger.manual, model=model
        )
        return Response({'task': task.id}, status=status.HTTP_201_CREATED)
|