diff --git a/apps/assets/api/system_user.py b/apps/assets/api/system_user.py index 27baaa017..8ec151285 100644 --- a/apps/assets/api/system_user.py +++ b/apps/assets/api/system_user.py @@ -98,8 +98,8 @@ class SystemUserTaskApi(generics.CreateAPIView): return task @staticmethod - def do_test(system_user): - task = test_system_user_connectivity_manual.delay(system_user) + def do_test(system_user, asset_ids): + task = test_system_user_connectivity_manual.delay(system_user, asset_ids) return task def get_object(self): @@ -109,16 +109,20 @@ class SystemUserTaskApi(generics.CreateAPIView): def perform_create(self, serializer): action = serializer.validated_data["action"] asset = serializer.validated_data.get('asset') - assets = serializer.validated_data.get('assets') or [] + + if asset: + assets = [asset] + else: + assets = serializer.validated_data.get('assets') or [] + + asset_ids = [asset.id for asset in assets] + asset_ids = asset_ids if asset_ids else None system_user = self.get_object() if action == 'push': - assets = [asset] if asset else assets - asset_ids = [asset.id for asset in assets] - asset_ids = asset_ids if asset_ids else None task = self.do_push(system_user, asset_ids) else: - task = self.do_test(system_user) + task = self.do_test(system_user, asset_ids) data = getattr(serializer, '_data', {}) data["task"] = task.id setattr(serializer, '_data', data) diff --git a/apps/assets/tasks/system_user_connectivity.py b/apps/assets/tasks/system_user_connectivity.py index 42b6f2331..87152c8cc 100644 --- a/apps/assets/tasks/system_user_connectivity.py +++ b/apps/assets/tasks/system_user_connectivity.py @@ -5,6 +5,7 @@ from collections import defaultdict from celery import shared_task from django.utils.translation import ugettext as _ +from assets.models import Asset from common.utils import get_logger from orgs.utils import tmp_to_org, org_aware_func from ..models import SystemUser @@ -96,9 +97,12 @@ def test_system_user_connectivity_util(system_user, assets, task_name): @shared_task(queue="ansible") @org_aware_func("system_user") -def test_system_user_connectivity_manual(system_user): +def test_system_user_connectivity_manual(system_user, asset_ids=None): task_name = _("Test system user connectivity: {}").format(system_user) - assets = system_user.get_related_assets() + if asset_ids: + assets = Asset.objects.filter(id__in=asset_ids) + else: + assets = system_user.get_related_assets() test_system_user_connectivity_util(system_user, assets, task_name)