From 5c1acae4c5da0cb68a318e9b0a24f55da63e3568 Mon Sep 17 00:00:00 2001
From: fit2bot <68588906+fit2bot@users.noreply.github.com>
Date: Wed, 8 Feb 2023 15:36:45 +0800
Subject: [PATCH] perf: push account ssh (#9467)

Co-authored-by: feng <1304903146@qq.com>
---
 apps/accounts/automations/base/manager.py     |  4 +-
 .../automations/change_secret/manager.py      | 14 ++--
 .../automations/push_account/manager.py       | 79 ++++++++++++++++---
 .../automations/verify_account/manager.py     |  4 +-
 4 files changed, 79 insertions(+), 22 deletions(-)

diff --git a/apps/accounts/automations/base/manager.py b/apps/accounts/automations/base/manager.py
index d2e96c7a6..2dd91d794 100644
--- a/apps/accounts/automations/base/manager.py
+++ b/apps/accounts/automations/base/manager.py
@@ -1,14 +1,14 @@
 from copy import deepcopy
 
 from common.utils import get_logger
-from accounts.const import AutomationTypes, SecretType
+from accounts.const import SecretType
 from assets.automations.base.manager import BasePlaybookManager
 from accounts.automations.methods import platform_automation_methods
 
 logger = get_logger(__name__)
 
 
-class PushOrVerifyHostCallbackMixin:
+class VerifyHostCallbackMixin:
     execution: callable
     get_accounts: callable
     host_account_mapper: dict
diff --git a/apps/accounts/automations/change_secret/manager.py b/apps/accounts/automations/change_secret/manager.py
index eb2d9d185..411506fe7 100644
--- a/apps/accounts/automations/change_secret/manager.py
+++ b/apps/accounts/automations/change_secret/manager.py
@@ -33,18 +33,12 @@ class ChangeSecretManager(AccountBasePlaybookManager):
             'ssh_key_change_strategy', SSHKeyStrategy.add
         )
         self.snapshot_account_usernames = self.execution.snapshot['accounts']
-        self._password_generated = None
-        self._ssh_key_generated = None
         self.name_recorder_mapper = {}  # 做个映射,方便后面处理
 
     @classmethod
     def method_type(cls):
         return AutomationTypes.change_secret
 
-    @lazyproperty
-    def related_accounts(self):
-        pass
-
     def get_kwargs(self, account, secret):
         kwargs = {}
         if self.secret_type != SecretType.SSH_KEY:
@@ -152,12 +146,16 @@ class ChangeSecretManager(AccountBasePlaybookManager):
     def on_runner_failed(self, runner, e):
         logger.error("Change secret error: ", e)
 
-    def run(self, *args, **kwargs):
+    def check_secret(self):
         if self.secret_strategy == SecretStrategy.custom \
                 and not self.execution.snapshot['secret']:
             print('Custom secret is empty')
-            return
+            return False
+        return True
 
+    def run(self, *args, **kwargs):
+        if not self.check_secret():
+            return
         super().run(*args, **kwargs)
         recorders = self.name_recorder_mapper.values()
         recorders = list(recorders)
diff --git a/apps/accounts/automations/push_account/manager.py b/apps/accounts/automations/push_account/manager.py
index 293a42967..4eb047ed0 100644
--- a/apps/accounts/automations/push_account/manager.py
+++ b/apps/accounts/automations/push_account/manager.py
@@ -1,26 +1,23 @@
+from copy import deepcopy
+
 from django.db.models import QuerySet
 
 from common.utils import get_logger
-from accounts.const import AutomationTypes
 from accounts.models import Account
-from ..base.manager import PushOrVerifyHostCallbackMixin, AccountBasePlaybookManager
+from accounts.const import AutomationTypes, SecretType
+from ..base.manager import AccountBasePlaybookManager
+from ..change_secret.manager import ChangeSecretManager
 
 logger = get_logger(__name__)
 
 
-class PushAccountManager(PushOrVerifyHostCallbackMixin, AccountBasePlaybookManager):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.secret_type = self.execution.snapshot['secret_type']
-        self.host_account_mapper = {}
+class PushAccountManager(ChangeSecretManager, AccountBasePlaybookManager):
 
     @classmethod
     def method_type(cls):
         return AutomationTypes.push_account
 
     def create_nonlocal_accounts(self, accounts, snapshot_account_usernames, asset):
-        secret = self.execution.snapshot['secret']
         secret_type = self.secret_type
         usernames = accounts.filter(secret_type=secret_type).values_list(
             'username', flat=True
@@ -29,7 +26,7 @@ class PushAccountManager(PushOrVerifyHostCallbackMixin, AccountBasePlaybookManag
         create_account_objs = [
             Account(
                 name=f'{username}-{secret_type}', username=username,
-                secret=secret, secret_type=secret_type, asset=asset,
+                secret_type=secret_type, asset=asset,
             )
             for username in create_usernames
         ]
@@ -50,6 +47,68 @@ class PushAccountManager(PushOrVerifyHostCallbackMixin, AccountBasePlaybookManag
         )
         return accounts
 
+    def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs):
+        host = super(ChangeSecretManager, self).host_callback(
+            host, asset=asset, account=account, automation=automation,
+            path_dir=path_dir, **kwargs
+        )
+        if host.get('error'):
+            return host
+
+        accounts = asset.accounts.all()
+        accounts = self.get_accounts(account, accounts)
+
+        inventory_hosts = []
+        host['secret_type'] = self.secret_type
+        for account in accounts:
+            h = deepcopy(host)
+            h['name'] += '_' + account.username
+            new_secret = self.get_secret()
+
+            private_key_path = None
+            if self.secret_type == SecretType.SSH_KEY:
+                private_key_path = self.generate_private_key_path(new_secret, path_dir)
+                new_secret = self.generate_public_key(new_secret)
+
+            self.name_recorder_mapper[h['name']] = {
+                'account': account, 'new_secret': new_secret,
+            }
+
+            h['kwargs'] = self.get_kwargs(account, new_secret)
+            h['account'] = {
+                'name': account.name,
+                'username': account.username,
+                'secret_type': account.secret_type,
+                'secret': new_secret,
+                'private_key_path': private_key_path
+            }
+            if asset.platform.type == 'oracle':
+                h['account']['mode'] = 'sysdba' if account.privileged else None
+            inventory_hosts.append(h)
+        return inventory_hosts
+
+    def on_host_success(self, host, result):
+        account_info = self.name_recorder_mapper.get(host)
+        if not account_info:
+            return
+        account = account_info['account']
+        new_secret = account_info['new_secret']
+        if not account:
+            return
+        account.secret = new_secret
+        account.save(update_fields=['secret'])
+
+    def on_host_error(self, host, error, result):
+        pass
+
+    def on_runner_failed(self, runner, e):
+        logger.error("Pust account error: ", e)
+
+    def run(self, *args, **kwargs):
+        if not self.check_secret():
+            return
+        super().run(*args, **kwargs)
+
     # @classmethod
     # def trigger_by_asset_create(cls, asset):
     #     automations = PushAccountAutomation.objects.filter(
diff --git a/apps/accounts/automations/verify_account/manager.py b/apps/accounts/automations/verify_account/manager.py
index fc25794bf..2b6c831dc 100644
--- a/apps/accounts/automations/verify_account/manager.py
+++ b/apps/accounts/automations/verify_account/manager.py
@@ -2,12 +2,12 @@ from django.db.models import QuerySet
 
 from accounts.const import AutomationTypes, Connectivity
 from common.utils import get_logger
-from ..base.manager import PushOrVerifyHostCallbackMixin, AccountBasePlaybookManager
+from ..base.manager import VerifyHostCallbackMixin, AccountBasePlaybookManager
 
 logger = get_logger(__name__)
 
 
-class VerifyAccountManager(PushOrVerifyHostCallbackMixin, AccountBasePlaybookManager):
+class VerifyAccountManager(VerifyHostCallbackMixin, AccountBasePlaybookManager):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)