perf: Optimize playbook manager

pull/14534/head
ibuler 2024-11-20 19:12:28 +08:00
parent a3b3254c35
commit 92f7209997
6 changed files with 197 additions and 106 deletions


@@ -38,8 +38,8 @@ def is_weak_password(password):
 @bulk_create_decorator(AccountRisk)
-def create_risk(account, risk):
-    pass
+def create_risk(data):
+    return AccountRisk(**data)

 @bulk_update_decorator(AccountRisk, update_fields=["details"])
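The hunk above relies on `bulk_create_decorator` and `bulk_update_decorator` to turn per-row saves into batched writes; their implementation is not part of this diff. A minimal sketch of the buffering pattern such a decorator typically follows (assumed, not the project's actual code):

from functools import wraps

def bulk_create_decorator(model, batch_size=100):
    # Buffer the objects returned by the wrapped factory and flush them
    # with a single bulk_create() once a batch fills up.
    def decorator(func):
        pending = []

        @wraps(func)
        def wrapper(*args, **kwargs):
            obj = func(*args, **kwargs)
            if obj is None:
                return None
            pending.append(obj)
            if len(pending) >= batch_size:
                model.objects.bulk_create(pending)
                pending.clear()
            return obj

        # Callers flush the tail of the buffer explicitly when they are done.
        def finish():
            if pending:
                model.objects.bulk_create(pending)
                pending.clear()

        wrapper.finish = finish
        return wrapper
    return decorator

With this shape, `create_risk(data)` can simply return an unsaved `AccountRisk(**data)` instance, as the new code does, and the decorator takes care of persisting it in batches.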


@@ -183,9 +183,15 @@ class GatherAccountsManager(AccountBasePlaybookManager):
             print("Runner failed: ", e)
             raise e

+    def on_host_error(self, host, error, result):
+        print(f'\033[31m {host} error: {error} \033[0m\n')
+        self.summary['error_assets'] += 1
+
     def on_host_success(self, host, result):
         info = self._get_nested_info(result, 'debug', 'res', 'info')
         asset = self.host_asset_mapper.get(host)
+        self.summary['success_assets'] += 1
         if asset and info:
             self._collect_asset_account_info(asset, info)
         else:
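The new callbacks assume the `error_assets` and `success_assets` keys already exist in `self.summary`; presumably they are seeded before the run, since that initialization is outside this diff. A hypothetical sketch of the seeding pattern:

from collections import defaultdict

class DemoManager:  # hypothetical; only shows the counter-seeding pattern
    def __init__(self):
        # defaultdict(int) lets callbacks do `self.summary['error_assets'] += 1`
        # without pre-declaring every key.
        self.summary = defaultdict(int)

    def on_host_error(self, host, error, result):
        self.summary['error_assets'] += 1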


@@ -66,7 +66,7 @@
                 <td>{{ forloop.counter }}</td>
                 <td>{{ account.asset }}</td>
                 <td>{{ account.username }}</td>
-                <td>{% trans 'Week password' %}</td>
+                <td style="color: red">{% trans 'Week password' %}</td>
             </tr>
             {% endfor %}
             </tbody>


@@ -25,32 +25,69 @@
             <td>{% trans 'Time using' %}: </td>
             <td>{{ execution.duration }}s</td>
         </tr>
         <tr>
             <td>{% trans 'Assets count' %}: </td>
             <td>{{ summary.assets }}</td>
         </tr>
         <tr>
-            <td>{% trans 'Account count' %}: </td>
-            <td>{{ summary.accounts }}</td>
+            <td>{% trans 'Asset success count' %}: </td>
+            <td>{{ summary.success_assets }}</td>
         </tr>
         <tr>
-            <td>{% trans 'Week password count' %}:</td>
-            <td> <span> {{ summary.weak_password }}</span></td>
+            <td>{% trans 'Asset failed count' %}: </td>
+            <td>{{ summary.fail_assets }}</td>
         </tr>
         <tr>
-            <td>{% trans 'Ok count' %}: </td>
-            <td>{{ summary.ok }}</td>
+            <td>{% trans 'Asset not support count' %}: </td>
+            <td>{{ summary.na_assets }}</td>
+        </tr>
+        <tr>
+            <td>{% trans 'Account new found count' %}: </td>
+            <td>{{ summary.new_accounts }}</td>
+        </tr>
+        <tr>
+            <td>{% trans 'Sudo changed count' %}:</td>
+            <td> <span> {{ summary.sudo_changed }}</span></td>
         </tr>
         <tr>
-            <td>{% trans 'No password count' %}: </td>
-            <td>{{ summary.no_secret }}</td>
+            <td>{% trans 'Groups changed count' %}:</td>
+            <td> <span> {{ summary.groups_changed }}</span></td>
+        </tr>
+        <tr>
+            <td>{% trans 'Authorized key changed count' %}:</td>
+            <td> <span> {{ summary.authorized_key_changed }}</span></td>
         </tr>
         </tbody>
     </table>
 </div>
 <div class='result'>
-    <p>{% trans 'Account check details' %}:</p>
+    <p>{% trans 'New found accounts' %}:</p>
+    <table style="">
+        <thead>
+        <tr>
+            <th>{% trans 'No.' %}</th>
+            <th>{% trans 'Asset' %}</th>
+            <th>{% trans 'Username' %}</th>
+        </tr>
+        </thead>
+        <tbody>
+        {% for account in result.new_accounts %}
+        <tr>
+            <td>{{ forloop.counter }}</td>
+            <td>{{ account.asset }}</td>
+            <td>{{ account.username }}</td>
+        </tr>
+        {% endfor %}
+        </tbody>
+    </table>
+</div>
+<div class='result'>
+    <p>{% trans 'New found risk' %}:</p>
     <table style="">
         <thead>
         <tr>
@@ -61,12 +98,12 @@
         </tr>
         </thead>
         <tbody>
-        {% for account in result.weak_password %}
+        {% for risk in result.risks %}
        <tr>
            <td>{{ forloop.counter }}</td>
-           <td>{{ account.asset }}</td>
-           <td>{{ account.username }}</td>
-           <td>{% trans 'Week password' %}</td>
+           <td>{{ risk.asset }}</td>
+           <td>{{ risk.username }}</td>
+           <td>{{ risk.risk }}</td>
        </tr>
        {% endfor %}
        </tbody>


@@ -30,46 +30,49 @@ class SSHTunnelManager:
     @staticmethod
     def file_to_json(path):
-        with open(path, 'r') as f:
+        with open(path, "r") as f:
             d = json.load(f)
         return d

     @staticmethod
     def json_to_file(path, data):
-        with open(path, 'w') as f:
+        with open(path, "w") as f:
             json.dump(data, f, indent=4, sort_keys=True)

     def local_gateway_prepare(self, runner):
         info = self.file_to_json(runner.inventory)
         servers, not_valid = [], []
-        for k, host in info['all']['hosts'].items():
-            jms_asset, jms_gateway = host.get('jms_asset'), host.get('jms_gateway')
+        for k, host in info["all"]["hosts"].items():
+            jms_asset, jms_gateway = host.get("jms_asset"), host.get("jms_gateway")
             if not jms_gateway:
                 continue
             try:
                 server = SSHTunnelForwarder(
-                    (jms_gateway['address'], jms_gateway['port']),
-                    ssh_username=jms_gateway['username'],
-                    ssh_password=jms_gateway['secret'],
-                    ssh_pkey=jms_gateway['private_key_path'],
-                    remote_bind_address=(jms_asset['address'], jms_asset['port'])
+                    (jms_gateway["address"], jms_gateway["port"]),
+                    ssh_username=jms_gateway["username"],
+                    ssh_password=jms_gateway["secret"],
+                    ssh_pkey=jms_gateway["private_key_path"],
+                    remote_bind_address=(jms_asset["address"], jms_asset["port"]),
                 )
                 server.start()
             except Exception as e:
-                err_msg = 'Gateway is not active: %s' % jms_asset.get('name', '')
-                print(f'\033[31m {err_msg} 原因: {e} \033[0m\n')
+                err_msg = "Gateway is not active: %s" % jms_asset.get("name", "")
+                print(f"\033[31m {err_msg} 原因: {e} \033[0m\n")
                 not_valid.append(k)
             else:
                 local_bind_port = server.local_bind_port
-                host['ansible_host'] = jms_asset['address'] = host[
-                    'login_host'] = interface.get_gateway_proxy_host()
-                host['ansible_port'] = jms_asset['port'] = host['login_port'] = local_bind_port
+                host["ansible_host"] = jms_asset["address"] = host["login_host"] = (
+                    interface.get_gateway_proxy_host()
+                )
+                host["ansible_port"] = jms_asset["port"] = host["login_port"] = (
+                    local_bind_port
+                )
                 servers.append(server)

         # 网域不可连接的,就不继续执行此资源的后续任务了
         for a in set(not_valid):
-            info['all']['hosts'].pop(a)
+            info["all"]["hosts"].pop(a)
         self.json_to_file(runner.inventory, info)
         self.gateway_servers[runner.id] = servers
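`local_gateway_clean(runner)`, called in the run loop further down but not shown in this hunk, presumably stops the forwarders stored in `gateway_servers`; a hedged sketch of that counterpart:

    def local_gateway_clean(self, runner):
        # Assumed counterpart to local_gateway_prepare: stop every tunnel
        # that was started for this runner and drop the bookkeeping entry.
        servers = self.gateway_servers.pop(runner.id, [])
        for server in servers:
            try:
                server.stop()
            except Exception:
                pass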
@@ -100,7 +103,7 @@ class BaseManager:
     def pre_run(self):
         self.execution.date_start = timezone.now()
-        self.execution.save(update_fields=['date_start'])
+        self.execution.save(update_fields=["date_start"])

     def update_execution(self):
         self.duration = int(time.time() - self.time_start)
@@ -108,7 +111,7 @@ class BaseManager:
         self.execution.duration = self.duration
         self.execution.summary = self.summary
         self.execution.result = self.result
-        self.execution.status = 'success'
+        self.execution.status = "success"

         with safe_db_connection():
             self.execution.save()
@@ -120,13 +123,13 @@ class BaseManager:
         raise NotImplementedError

     def get_report_subject(self):
-        return f'Automation {self.execution.id} finished'
+        return f"Automation {self.execution.id} finished"

     def get_report_context(self):
         return {
-            'execution': self.execution,
-            'summary': self.execution.summary,
-            'result': self.execution.result
+            "execution": self.execution,
+            "summary": self.execution.summary,
+            "result": self.execution.result,
         }

     def send_report_if_need(self):
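`send_report_if_need` itself is unchanged here; given `get_report_subject` and `get_report_context` above, it presumably renders a template like the HTML shown earlier and mails it. A hedged sketch using standard Django helpers (the template path and recipient lookup are assumptions, not taken from this diff):

from django.core.mail import send_mail
from django.template.loader import render_to_string

def send_report_if_need(self):
    # Illustrative only: render the report template with the context built
    # by get_report_context() and email it to the task's recipients.
    context = self.get_report_context()
    html = render_to_string('assets/automations/report.html', context)  # path assumed
    recipients = getattr(self.execution, 'recipients', [])              # attribute assumed
    if recipients:
        send_mail(self.get_report_subject(), '', None, recipients, html_message=html)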
@@ -164,30 +167,44 @@ class BaseManager:
         return json.dumps(data, indent=4, sort_keys=True)


+class PlaybookUtil:
+    bulk_size = 100
+    ansible_account_policy = "privileged_first"
+    ansible_account_prefer = "root,Administrator"
+
+    def __init__(self, assets, playbook_dir, inventory_path):
+        self.assets = assets
+        self.playbook_dir = playbook_dir
+        self.inventory_path = inventory_path
+
+
 class BasePlaybookManager(BaseManager):
     bulk_size = 100
-    ansible_account_policy = 'privileged_first'
-    ansible_account_prefer = 'root,Administrator'
+    ansible_account_policy = "privileged_first"
+    ansible_account_prefer = "root,Administrator"

     def __init__(self, execution):
         super().__init__(execution)
+        # example: {'gather_fact_windows': {'id': 'gather_fact_windows', 'name': '', 'method': 'gather_fact', ...} }
         self.method_id_meta_mapper = {
-            method['id']: method
+            method["id"]: method
             for method in self.platform_automation_methods
-            if method['method'] == self.__class__.method_type()
+            if method["method"] == self.__class__.method_type()
         }
         # 根据执行方式就行分组, 不同资产的改密、推送等操作可能会使用不同的执行方式
         # 然后根据执行方式分组, 再根据 bulk_size 分组, 生成不同的 playbook
         self.playbooks = []
-        params = self.execution.snapshot.get('params')
+        params = self.execution.snapshot.get("params")
         self.params = params or {}

     def get_params(self, automation, method_type):
-        method_attr = '{}_method'.format(method_type)
-        method_params = '{}_params'.format(method_type)
+        method_attr = "{}_method".format(method_type)
+        method_params = "{}_params".format(method_type)
         method_id = getattr(automation, method_attr)
         automation_params = getattr(automation, method_params)
-        serializer = self.method_id_meta_mapper[method_id]['params_serializer']
+        serializer = self.method_id_meta_mapper[method_id]["params_serializer"]

         if serializer is None:
             return {}
@@ -211,11 +228,14 @@ class BasePlaybookManager(BaseManager):
     def prepare_runtime_dir(self):
         ansible_dir = settings.ANSIBLE_DIR
-        task_name = self.execution.snapshot['name']
-        dir_name = '{}_{}'.format(task_name.replace(' ', '_'), self.execution.id)
+        task_name = self.execution.snapshot["name"]
+        dir_name = "{}_{}".format(task_name.replace(" ", "_"), self.execution.id)
         path = os.path.join(
-            ansible_dir, 'automations', self.execution.snapshot['type'],
-            dir_name, timezone.now().strftime('%Y%m%d_%H%M%S')
+            ansible_dir,
+            "automations",
+            self.execution.snapshot["type"],
+            dir_name,
+            timezone.now().strftime("%Y%m%d_%H%M%S"),
         )
         if not os.path.exists(path):
             os.makedirs(path, exist_ok=True, mode=0o755)
@@ -225,13 +245,13 @@
     def runtime_dir(self):
         path = self.prepare_runtime_dir()
         if settings.DEBUG_DEV:
-            msg = 'Ansible runtime dir: {}'.format(path)
+            msg = "Ansible runtime dir: {}".format(path)
             print(msg)
         return path

     @staticmethod
     def write_cert_to_file(filename, content):
-        with open(filename, 'w') as f:
+        with open(filename, "w") as f:
             f.write(content)
         return filename
@@ -239,41 +259,28 @@
         if not path_dir:
             return host

-        specific = host.get('jms_asset', {}).get('secret_info', {})
-        cert_fields = ('ca_cert', 'client_key', 'client_cert')
+        specific = host.get("jms_asset", {}).get("secret_info", {})
+        cert_fields = ("ca_cert", "client_key", "client_cert")
         filtered = list(filter(lambda x: specific.get(x), cert_fields))
         if not filtered:
             return host

-        cert_dir = os.path.join(path_dir, 'certs')
+        cert_dir = os.path.join(path_dir, "certs")
         if not os.path.exists(cert_dir):
             os.makedirs(cert_dir, 0o700, True)

         for f in filtered:
-            result = self.write_cert_to_file(
-                os.path.join(cert_dir, f), specific.get(f)
-            )
-            host['jms_asset']['secret_info'][f] = result
+            result = self.write_cert_to_file(os.path.join(cert_dir, f), specific.get(f))
+            host["jms_asset"]["secret_info"][f] = result
         return host

+    def on_host_method_not_enabled(self, host, **kwargs):
+        host["error"] = _("{} disabled".format(self.__class__.method_type()))
+
     def host_callback(self, host, automation=None, **kwargs):
         method_type = self.__class__.method_type()
-        enabled_attr = '{}_enabled'.format(method_type)
-        method_attr = '{}_method'.format(method_type)
-
-        method_enabled = (
-            automation
-            and getattr(automation, enabled_attr)
-            and getattr(automation, method_attr)
-            and getattr(automation, method_attr) in self.method_id_meta_mapper
-        )
-        if not method_enabled:
-            host['error'] = _('{} disabled'.format(self.__class__.method_type()))
-            return host
-
-        host = self.convert_cert_to_file(host, kwargs.get('path_dir'))
-        host['params'] = self.get_params(automation, method_type)
+        host = self.convert_cert_to_file(host, kwargs.get("path_dir"))
+        host["params"] = self.get_params(automation, method_type)
         return host

     @staticmethod
@@ -282,16 +289,16 @@ class BasePlaybookManager(BaseManager):
     @staticmethod
     def generate_private_key_path(secret, path_dir):
-        key_name = '.' + hashlib.md5(secret.encode('utf-8')).hexdigest()
+        key_name = "." + hashlib.md5(secret.encode("utf-8")).hexdigest()
         key_path = os.path.join(path_dir, key_name)
         if not os.path.exists(key_path):
             # https://github.com/ansible/ansible-runner/issues/544
             # ssh requires OpenSSH format keys to have a full ending newline.
             # It does not require this for old-style PEM keys.
-            with open(key_path, 'w') as f:
+            with open(key_path, "w") as f:
                 f.write(secret)
-                if is_openssh_format_key(secret.encode('utf-8')):
+                if is_openssh_format_key(secret.encode("utf-8")):
                     f.write("\n")
             os.chmod(key_path, 0o400)
         return key_path
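For reference, the path returned here is presumably wired into the generated inventory as the host's key file; an illustrative (assumed) usage with the standard Ansible connection variable:

# Illustrative only; the inventory wiring is not shown in this diff.
key_path = self.generate_private_key_path(secret, path_dir)
host['ansible_ssh_private_key_file'] = key_path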
@@ -309,50 +316,87 @@ class BasePlaybookManager(BaseManager):
     @staticmethod
     def generate_playbook(method, sub_playbook_dir):
-        method_playbook_dir_path = method['dir']
-        sub_playbook_path = os.path.join(sub_playbook_dir, 'project', 'main.yml')
+        method_playbook_dir_path = method["dir"]
+        sub_playbook_path = os.path.join(sub_playbook_dir, "project", "main.yml")
         shutil.copytree(method_playbook_dir_path, os.path.dirname(sub_playbook_path))

-        with open(sub_playbook_path, 'r') as f:
+        with open(sub_playbook_path, "r") as f:
             plays = yaml.safe_load(f)
         for play in plays:
-            play['hosts'] = 'all'
+            play["hosts"] = "all"

-        with open(sub_playbook_path, 'w') as f:
+        with open(sub_playbook_path, "w") as f:
             yaml.safe_dump(plays, f)
         return sub_playbook_path

+    def on_assets_not_ansible_enabled(self, assets):
+        for asset in assets:
+            print("\t{}".format(asset))
+
+    def on_assets_not_method_enabled(self, assets, method_id):
+        for asset in assets:
+            print("\t{}".format(asset))
+
+    def on_playbook_not_found(self, assets):
+        pass
+
+    def check_automation_enabled(self, platform, assets):
+        if not platform.automation or not platform.automation.ansible_enabled:
+            print(_(" - Platform {} ansible disabled").format(platform.name))
+            self.on_assets_not_ansible_enabled(assets)
+
+        automation = platform.automation
+        method_type = self.__class__.method_type()
+        enabled_attr = "{}_enabled".format(method_type)
+        method_attr = "{}_method".format(method_type)
+
+        method_enabled = (
+            automation
+            and getattr(automation, enabled_attr)
+            and getattr(automation, method_attr)
+            and getattr(automation, method_attr) in self.method_id_meta_mapper
+        )
+        if not method_enabled:
+            self.on_assets_not_method_enabled(assets, method_type)
+
     def get_runners(self):
         assets_group_by_platform = self.get_assets_group_by_platform()
         if settings.DEBUG_DEV:
-            msg = 'Assets group by platform: {}'.format(dict(assets_group_by_platform))
+            msg = "Assets group by platform: {}".format(dict(assets_group_by_platform))
             print(msg)

         runners = []
         for platform, assets in assets_group_by_platform.items():
             if not assets:
                 continue
-            if not platform.automation or not platform.automation.ansible_enabled:
-                print(_(" - Platform {} ansible disabled").format(platform.name))
+
+            if not self.check_automation_enabled(platform, assets):
                 continue

-            assets_bulked = [assets[i:i + self.bulk_size] for i in range(0, len(assets), self.bulk_size)]
+            # 避免一个任务太大,分批执行
+            assets_bulked = [
+                assets[i : i + self.bulk_size]
+                for i in range(0, len(assets), self.bulk_size)
+            ]

             for i, _assets in enumerate(assets_bulked, start=1):
-                sub_dir = '{}_{}'.format(platform.name, i)
+                sub_dir = "{}_{}".format(platform.name, i)
                 playbook_dir = os.path.join(self.runtime_dir, sub_dir)
-                inventory_path = os.path.join(self.runtime_dir, sub_dir, 'hosts.json')
-                method_id = getattr(platform.automation, '{}_method'.format(self.__class__.method_type()))
+                inventory_path = os.path.join(self.runtime_dir, sub_dir, "hosts.json")
+                # method_id = getattr(
+                #     platform.automation,
+                #     "{}_method".format(self.__class__.method_type()),
+                # )
                 method = self.method_id_meta_mapper.get(method_id)
-                if not method:
-                    logger.error("Method not found: {}".format(method_id))
-                    continue
-                protocol = method.get('protocol')
+                protocol = method.get("protocol")
                 self.generate_inventory(_assets, inventory_path, protocol)
                 playbook_path = self.generate_playbook(method, playbook_dir)
                 if not playbook_path:
+                    self.on_playbook_not_found(_assets)
                     continue

                 runer = SuperPlaybookRunner(
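The new `on_assets_not_ansible_enabled` / `on_assets_not_method_enabled` / `on_playbook_not_found` hooks give subclasses a place to record skipped assets. A hypothetical override showing how such a hook could feed the `na_assets` counter used by the report template (illustrative, not the project's code):

class DemoPlaybookManager(BasePlaybookManager):  # hypothetical subclass
    def on_assets_not_method_enabled(self, assets, method_id):
        # Count assets whose platform lacks this automation method so the
        # report's "Asset not support count" row has data.
        self.summary['na_assets'] = self.summary.get('na_assets', 0) + len(assets)
        super().on_assets_not_method_enabled(assets, method_id)

    def on_playbook_not_found(self, assets):
        print("No playbook generated for {} asset(s), skipping".format(len(assets)))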
@@ -362,9 +406,9 @@ class BasePlaybookManager(BaseManager):
                     callback=PlaybookCallback(),
                 )

-                with open(inventory_path, 'r') as f:
+                with open(inventory_path, "r") as f:
                     inventory_data = json.load(f)
-                if not inventory_data['all'].get('hosts'):
+                if not inventory_data["all"].get("hosts"):
                     continue

                 runners.append(runer)
@@ -375,14 +419,14 @@ class BasePlaybookManager(BaseManager):
     def on_host_error(self, host, error, result):
         if settings.DEBUG_DEV:
-            print('host error: {} -> {}'.format(host, error))
+            print("host error: {} -> {}".format(host, error))

     def _on_host_success(self, host, result, hosts):
-        self.on_host_success(host, result.get("ok", ''))
+        self.on_host_success(host, result.get("ok", ""))

     def _on_host_error(self, host, result, hosts):
-        error = hosts.get(host, '')
-        detail = result.get('failures', '') or result.get('dark', '')
+        error = hosts.get(host, "")
+        detail = result.get("failures", "") or result.get("dark", "")
         self.on_host_error(host, error, detail)

     def on_runner_success(self, runner, cb):
@@ -390,9 +434,9 @@ class BasePlaybookManager(BaseManager):
         for state, hosts in summary.items():
             # 错误行为为host 是 dict ok 时是 list
-            if state == 'ok':
+            if state == "ok":
                 handler = self._on_host_success
-            elif state == 'skipped':
+            elif state == "skipped":
                 continue
             else:
                 handler = self._on_host_error
@@ -413,7 +457,11 @@ class BasePlaybookManager(BaseManager):
         print(_(">>> Task preparation phase"), end="\n")
         runners = self.get_runners()
         if len(runners) > 1:
-            print(_(">>> Executing tasks in batches, total {runner_count}").format(runner_count=len(runners)))
+            print(
+                _(">>> Executing tasks in batches, total {runner_count}").format(
+                    runner_count=len(runners)
+                )
+            )
         elif len(runners) == 1:
             print(_(">>> Start executing tasks"))
         else:
@@ -433,6 +481,4 @@ class BasePlaybookManager(BaseManager):
                 self.on_runner_failed(runner, e)
             finally:
                 ssh_tunnel.local_gateway_clean(runner)
-                print('\n')
+                print("\n")
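Pieced together from the surrounding hunks, the execution loop has roughly the shape below; the `runner.run()` call and its return value are assumptions, since that part of the method is outside this diff:

# Assumed shape of the loop the hunk above belongs to.
ssh_tunnel = SSHTunnelManager()
for runner in runners:
    ssh_tunnel.local_gateway_prepare(runner)    # start gateway port-forwards
    try:
        cb = runner.run()                       # call signature assumed
        self.on_runner_success(runner, cb)
    except Exception as e:
        self.on_runner_failed(runner, e)
    finally:
        ssh_tunnel.local_gateway_clean(runner)  # stop the forwarders
        print("\n")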


@@ -74,6 +74,8 @@ def sorted_methods(methods):
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 platform_automation_methods = get_platform_automation_methods(BASE_DIR)
+print("platform_automation_methods: ")
+print(json.dumps(platform_automation_methods, indent=4))

 if __name__ == '__main__':
     print(json.dumps(platform_automation_methods, indent=4))