diff --git a/apps/assets/api/asset.py b/apps/assets/api/asset.py index 7ce3b1f8e..d36246dd9 100644 --- a/apps/assets/api/asset.py +++ b/apps/assets/api/asset.py @@ -3,6 +3,8 @@ from assets.api import FilterAssetByNodeMixin from rest_framework.viewsets import ModelViewSet from rest_framework.generics import RetrieveAPIView +from rest_framework.response import Response +from rest_framework import status from django.shortcuts import get_object_or_404 from common.utils import get_logger, get_object_or_none @@ -12,7 +14,7 @@ from orgs.mixins import generics from ..models import Asset, Node, Platform from .. import serializers from ..tasks import ( - update_asset_hardware_info_manual, test_asset_connectivity_manual + update_assets_hardware_info_manual, test_assets_connectivity_manual ) from ..filters import FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend @@ -21,7 +23,7 @@ logger = get_logger(__file__) __all__ = [ 'AssetViewSet', 'AssetPlatformRetrieveApi', 'AssetGatewayListApi', 'AssetPlatformViewSet', - 'AssetTaskCreateApi', + 'AssetTaskCreateApi', 'AssetsTaskCreateApi', ] @@ -90,26 +92,38 @@ class AssetPlatformViewSet(ModelViewSet): return super().check_object_permissions(request, obj) -class AssetTaskCreateApi(generics.CreateAPIView): +class AssetsTaskMixin: + def perform_assets_task(self, serializer): + data = serializer.validated_data + assets = data['assets'] + action = data['action'] + if action == "refresh": + task = update_assets_hardware_info_manual.delay(assets) + else: + task = test_assets_connectivity_manual.delay(assets) + data = getattr(serializer, '_data', {}) + data["task"] = task.id + setattr(serializer, '_data', data) + + def perform_create(self, serializer): + self.perform_assets_task(serializer) + + +class AssetTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView): model = Asset serializer_class = serializers.AssetTaskSerializer permission_classes = (IsOrgAdmin,) - def get_object(self): - pk = self.kwargs.get("pk") - instance = get_object_or_404(Asset, pk=pk) - return instance + def create(self, request, *args, **kwargs): + pk = self.kwargs.get('pk') + request.data['assets'] = [pk] + return super().create(request, *args, **kwargs) - def perform_create(self, serializer): - asset = self.get_object() - action = serializer.validated_data["action"] - if action == "refresh": - task = update_asset_hardware_info_manual.delay(asset) - else: - task = test_asset_connectivity_manual.delay(asset) - data = getattr(serializer, '_data', {}) - data["task"] = task.id - setattr(serializer, '_data', data) + +class AssetsTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView): + model = Asset + serializer_class = serializers.AssetTaskSerializer + permission_classes = (IsOrgAdmin,) class AssetGatewayListApi(generics.ListAPIView): diff --git a/apps/assets/serializers/asset.py b/apps/assets/serializers/asset.py index ce8b54ca9..66effa20b 100644 --- a/apps/assets/serializers/asset.py +++ b/apps/assets/serializers/asset.py @@ -204,3 +204,6 @@ class AssetTaskSerializer(serializers.Serializer): ) task = serializers.CharField(read_only=True) action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True) + assets = serializers.PrimaryKeyRelatedField( + queryset=Asset.objects, required=False, allow_empty=True, many=True + ) diff --git a/apps/assets/tasks/asset_connectivity.py b/apps/assets/tasks/asset_connectivity.py index 8c02d0db1..ea4b90ea6 100644 --- a/apps/assets/tasks/asset_connectivity.py +++ b/apps/assets/tasks/asset_connectivity.py @@ -14,7 +14,7 @@ from .utils import clean_ansible_task_hosts, group_asset_by_platform logger = get_logger(__file__) __all__ = [ 'test_asset_connectivity_util', 'test_asset_connectivity_manual', - 'test_node_assets_connectivity_manual', + 'test_node_assets_connectivity_manual', 'test_assets_connectivity_manual', ] @@ -82,6 +82,17 @@ def test_asset_connectivity_manual(asset): return True, "" +@shared_task(queue="ansible") +def test_assets_connectivity_manual(assets): + task_name = _("Test assets connectivity: {}").format([asset.hostname for asset in assets]) + summary = test_asset_connectivity_util(assets, task_name=task_name) + + if summary.get('dark'): + return False, summary['dark'] + else: + return True, "" + + @shared_task(queue="ansible") def test_node_assets_connectivity_manual(node): task_name = _("Test if the assets under the node are connectable: {}".format(node.name)) diff --git a/apps/assets/tasks/gather_asset_hardware_info.py b/apps/assets/tasks/gather_asset_hardware_info.py index aa79e1655..daad3d694 100644 --- a/apps/assets/tasks/gather_asset_hardware_info.py +++ b/apps/assets/tasks/gather_asset_hardware_info.py @@ -19,6 +19,7 @@ disk_pattern = re.compile(r'^hd|sd|xvd|vd|nv') __all__ = [ 'update_assets_hardware_info_util', 'update_asset_hardware_info_manual', 'update_assets_hardware_info_period', 'update_node_assets_hardware_info_manual', + 'update_assets_hardware_info_manual', ] @@ -114,6 +115,12 @@ def update_asset_hardware_info_manual(asset): update_assets_hardware_info_util([asset], task_name=task_name) +@shared_task(queue="ansible") +def update_assets_hardware_info_manual(assets): + task_name = _("Update assets hardware info: {}").format([asset.hostname for asset in assets]) + update_assets_hardware_info_util(assets, task_name=task_name) + + @shared_task(queue="ansible") def update_assets_hardware_info_period(): """ diff --git a/apps/assets/urls/api_urls.py b/apps/assets/urls/api_urls.py index d81520b65..707a8e73d 100644 --- a/apps/assets/urls/api_urls.py +++ b/apps/assets/urls/api_urls.py @@ -36,6 +36,7 @@ urlpatterns = [ path('assets//gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'), path('assets//platform/', api.AssetPlatformRetrieveApi.as_view(), name='asset-platform-detail'), path('assets//tasks/', api.AssetTaskCreateApi.as_view(), name='asset-task-create'), + path('assets/tasks/', api.AssetsTaskCreateApi.as_view(), name='assets-task-create'), path('asset-users/tasks/', api.AssetUserTaskCreateAPI.as_view(), name='asset-user-task-create'),