perf: 优化 playbook manager

pam
ibuler 1 week ago
parent a3b3254c35
commit 92f7209997

@ -38,8 +38,8 @@ def is_weak_password(password):
@bulk_create_decorator(AccountRisk)
def create_risk(data):
    """Build an ``AccountRisk`` instance from a dict of field values.

    Args:
        data: Mapping of ``AccountRisk`` field names to values; it is
            expanded into keyword arguments.

    Returns:
        The constructed ``AccountRisk`` object.

    NOTE(review): persistence is presumably batched by
    ``bulk_create_decorator`` (defined outside this view) - confirm there.
    """
    return AccountRisk(**data)
@bulk_update_decorator(AccountRisk, update_fields=["details"])

@ -183,9 +183,15 @@ class GatherAccountsManager(AccountBasePlaybookManager):
print("Runner failed: ", e)
raise e
def on_host_error(self, host, error, result):
    """Report a failed host on stdout and count it in the run summary.

    Args:
        host: Name of the host that failed.
        error: Error description shown to the operator.
        result: Accepted for interface compatibility; not used here.
    """
    # ANSI red so the failure stands out in the task output.
    message = f'\033[31m {host} error: {error} \033[0m\n'
    print(message)
    self.summary['error_assets'] += 1
def on_host_success(self, host, result):
info = self._get_nested_info(result, 'debug', 'res', 'info')
asset = self.host_asset_mapper.get(host)
self.summary['success_assets'] += 1
if asset and info:
self._collect_asset_account_info(asset, info)
else:

@ -66,7 +66,7 @@
<td>{{ forloop.counter }}</td>
<td>{{ account.asset }}</td>
<td>{{ account.username }}</td>
<td>{% trans 'Week password' %}</td>
<td style="color: red">{% trans 'Week password' %}</td>
</tr>
{% endfor %}
</tbody>

@ -25,48 +25,85 @@
<td>{% trans 'Time using' %}: </td>
<td>{{ execution.duration }}s</td>
</tr>
<tr>
<td>{% trans 'Assets count' %}: </td>
<td>{{ summary.assets }}</td>
</tr>
<tr>
<td>{% trans 'Account count' %}: </td>
<td>{{ summary.accounts }}</td>
<td>{% trans 'Asset success count' %}: </td>
<td>{{ summary.success_assets }}</td>
</tr>
<tr>
<td>{% trans 'Asset failed count' %}: </td>
<td>{{ summary.fail_assets }}</td>
</tr>
<tr>
<td>{% trans 'Asset not support count' %}: </td>
<td>{{ summary.na_assets }}</td>
</tr>
<tr>
<td>{% trans 'Week password count' %}:</td>
<td> <span> {{ summary.weak_password }}</span></td>
<td>{% trans 'Account new found count' %}: </td>
<td>{{ summary.new_accounts }}</td>
</tr>
<tr>
<td>{% trans 'Ok count' %}: </td>
<td>{{ summary.ok }}</td>
<td>{% trans 'Sudo changed count' %}:</td>
<td> <span> {{ summary.sudo_changed }}</span></td>
</tr>
<tr>
<td>{% trans 'No password count' %}: </td>
<td>{{ summary.no_secret }}</td>
<td>{% trans 'Groups changed count' %}:</td>
<td> <span> {{ summary.groups_changed }}</span></td>
</tr>
<tr>
<td>{% trans 'Authorized key changed count' %}:</td>
<td> <span> {{ summary.authorized_key_changed }}</span></td>
</tr>
</tbody>
</table>
</div>
<div class='result'>
<p>{% trans 'Account check details' %}:</p>
<p>{% trans 'New found accounts' %}:</p>
<table style="">
<thead>
<tr>
<th>{% trans 'No.' %}</th>
<th>{% trans 'Asset' %}</th>
<th>{% trans 'Username' %}</th>
<th>{% trans 'Result' %}</th>
</tr>
</thead>
<tbody>
{% for account in result.weak_password %}
{% for account in result.new_accounts %}
<tr>
<td>{{ forloop.counter }}</td>
<td>{{ account.asset }}</td>
<td>{{ account.username }}</td>
<td>{% trans 'Week password' %}</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<div class='result'>
<p>{% trans 'New found risk' %}:</p>
<table style="">
<thead>
<tr>
<th>{% trans 'No.' %}</th>
<th>{% trans 'Asset' %}</th>
<th>{% trans 'Username' %}</th>
<th>{% trans 'Result' %}</th>
</tr>
</thead>
<tbody>
{% for risk in result.risks %}
<tr>
<td>{{ forloop.counter }}</td>
<td>{{ risk.asset }}</td>
<td>{{ risk.username }}</td>
<td>{{ risk.risk }}</td>
</tr>
{% endfor %}
</tbody>

@ -30,46 +30,49 @@ class SSHTunnelManager:
@staticmethod
def file_to_json(path):
with open(path, 'r') as f:
with open(path, "r") as f:
d = json.load(f)
return d
@staticmethod
def json_to_file(path, data):
with open(path, 'w') as f:
with open(path, "w") as f:
json.dump(data, f, indent=4, sort_keys=True)
def local_gateway_prepare(self, runner):
    """Start an SSH tunnel for every inventory host that declares a gateway.

    Reads the JSON inventory at ``runner.inventory``. For each host with a
    ``jms_gateway`` entry an ``SSHTunnelForwarder`` towards the asset is
    started. On success the host's connection fields are rewritten to the
    local end of the tunnel; on failure the host is removed from the
    inventory. The inventory is then written back and the started tunnels
    are stored in ``self.gateway_servers[runner.id]``.

    Args:
        runner: Object exposing ``inventory`` (path of the JSON inventory)
            and ``id``.
    """
    inventory = self.file_to_json(runner.inventory)
    tunnels = []
    unreachable = []

    for name, host in inventory["all"]["hosts"].items():
        jms_asset = host.get("jms_asset")
        jms_gateway = host.get("jms_gateway")
        if not jms_gateway:
            continue

        try:
            server = SSHTunnelForwarder(
                (jms_gateway["address"], jms_gateway["port"]),
                ssh_username=jms_gateway["username"],
                ssh_password=jms_gateway["secret"],
                ssh_pkey=jms_gateway["private_key_path"],
                remote_bind_address=(jms_asset["address"], jms_asset["port"]),
            )
            server.start()
        except Exception as e:
            err_msg = "Gateway is not active: %s" % jms_asset.get("name", "")
            print(f"\033[31m {err_msg} 原因: {e} \033[0m\n")
            unreachable.append(name)
            continue

        # Tunnel is up: point ansible and the login info at its local end.
        local_bind_port = server.local_bind_port
        proxy_host = interface.get_gateway_proxy_host()
        host["ansible_host"] = proxy_host
        jms_asset["address"] = proxy_host
        host["login_host"] = proxy_host
        host["ansible_port"] = local_bind_port
        jms_asset["port"] = local_bind_port
        host["login_port"] = local_bind_port
        tunnels.append(server)

    # Hosts whose gateway could not be reached are dropped so that no
    # further tasks are executed against them.
    for name in set(unreachable):
        inventory["all"]["hosts"].pop(name)

    self.json_to_file(runner.inventory, inventory)
    self.gateway_servers[runner.id] = tunnels
@ -100,7 +103,7 @@ class BaseManager:
def pre_run(self):
    """Stamp the execution with the current time and save only that field."""
    execution = self.execution
    execution.date_start = timezone.now()
    execution.save(update_fields=["date_start"])
def update_execution(self):
self.duration = int(time.time() - self.time_start)
@ -108,7 +111,7 @@ class BaseManager:
self.execution.duration = self.duration
self.execution.summary = self.summary
self.execution.result = self.result
self.execution.status = 'success'
self.execution.status = "success"
with safe_db_connection():
self.execution.save()
@ -120,13 +123,13 @@ class BaseManager:
raise NotImplementedError
def get_report_subject(self):
    """Return the report subject line, which embeds the execution id."""
    execution_id = self.execution.id
    return f"Automation {execution_id} finished"
def get_report_context(self):
    """Build the report context.

    Returns:
        dict: ``execution`` (the execution object itself) plus its
        ``summary`` and ``result`` attributes.
    """
    execution = self.execution
    context = {"execution": execution}
    context["summary"] = execution.summary
    context["result"] = execution.result
    return context
def send_report_if_need(self):
@ -164,30 +167,44 @@ class BaseManager:
return json.dumps(data, indent=4, sort_keys=True)
class PlaybookUtil:
    """Bundle a set of assets with a playbook directory and inventory path."""

    # Batch size and account-selection defaults, declared as class attributes.
    bulk_size = 100
    ansible_account_policy = "privileged_first"
    ansible_account_prefer = "root,Administrator"

    def __init__(self, assets, playbook_dir, inventory_path):
        """Store the three arguments unchanged on the instance."""
        self.assets, self.playbook_dir, self.inventory_path = (
            assets,
            playbook_dir,
            inventory_path,
        )
class BasePlaybookManager(BaseManager):
bulk_size = 100
ansible_account_policy = 'privileged_first'
ansible_account_prefer = 'root,Administrator'
ansible_account_policy = "privileged_first"
ansible_account_prefer = "root,Administrator"
def __init__(self, execution):
    """Index the automation methods of this manager's type and load params.

    Args:
        execution: Execution object; its ``snapshot`` mapping supplies the
            optional ``params`` entry.
    """
    super().__init__(execution)

    # Keep only the methods whose "method" matches this manager's type,
    # keyed by method id.
    # example: {'gather_fact_windows': {'id': 'gather_fact_windows', 'name': '', 'method': 'gather_fact', ...} }
    mapper = {}
    for method in self.platform_automation_methods:
        if method["method"] == self.__class__.method_type():
            mapper[method["id"]] = method
    self.method_id_meta_mapper = mapper

    # Assets are grouped by execution method, because changing secrets,
    # pushing accounts, etc. may use different methods for different assets.
    # Each group is then split by bulk_size to generate separate playbooks.
    self.playbooks = []
    self.params = self.execution.snapshot.get("params") or {}
def get_params(self, automation, method_type):
method_attr = '{}_method'.format(method_type)
method_params = '{}_params'.format(method_type)
method_attr = "{}_method".format(method_type)
method_params = "{}_params".format(method_type)
method_id = getattr(automation, method_attr)
automation_params = getattr(automation, method_params)
serializer = self.method_id_meta_mapper[method_id]['params_serializer']
serializer = self.method_id_meta_mapper[method_id]["params_serializer"]
if serializer is None:
return {}
@ -211,11 +228,14 @@ class BasePlaybookManager(BaseManager):
def prepare_runtime_dir(self):
ansible_dir = settings.ANSIBLE_DIR
task_name = self.execution.snapshot['name']
dir_name = '{}_{}'.format(task_name.replace(' ', '_'), self.execution.id)
task_name = self.execution.snapshot["name"]
dir_name = "{}_{}".format(task_name.replace(" ", "_"), self.execution.id)
path = os.path.join(
ansible_dir, 'automations', self.execution.snapshot['type'],
dir_name, timezone.now().strftime('%Y%m%d_%H%M%S')
ansible_dir,
"automations",
self.execution.snapshot["type"],
dir_name,
timezone.now().strftime("%Y%m%d_%H%M%S"),
)
if not os.path.exists(path):
os.makedirs(path, exist_ok=True, mode=0o755)
@ -225,13 +245,13 @@ class BasePlaybookManager(BaseManager):
def runtime_dir(self):
    """Return the path produced by ``prepare_runtime_dir``.

    The path is also printed when ``settings.DEBUG_DEV`` is truthy.

    NOTE(review): callers read ``self.runtime_dir`` as an attribute, so a
    property-style decorator presumably sits just above this view - confirm.
    """
    runtime_path = self.prepare_runtime_dir()
    if not settings.DEBUG_DEV:
        return runtime_path
    print("Ansible runtime dir: {}".format(runtime_path))
    return runtime_path
@staticmethod
def write_cert_to_file(filename, content):
with open(filename, 'w') as f:
with open(filename, "w") as f:
f.write(content)
return filename
@ -239,41 +259,28 @@ class BasePlaybookManager(BaseManager):
if not path_dir:
return host
specific = host.get('jms_asset', {}).get('secret_info', {})
cert_fields = ('ca_cert', 'client_key', 'client_cert')
specific = host.get("jms_asset", {}).get("secret_info", {})
cert_fields = ("ca_cert", "client_key", "client_cert")
filtered = list(filter(lambda x: specific.get(x), cert_fields))
if not filtered:
return host
cert_dir = os.path.join(path_dir, 'certs')
cert_dir = os.path.join(path_dir, "certs")
if not os.path.exists(cert_dir):
os.makedirs(cert_dir, 0o700, True)
for f in filtered:
result = self.write_cert_to_file(
os.path.join(cert_dir, f), specific.get(f)
)
host['jms_asset']['secret_info'][f] = result
result = self.write_cert_to_file(os.path.join(cert_dir, f), specific.get(f))
host["jms_asset"]["secret_info"][f] = result
return host
def on_host_method_not_enabled(self, host, **kwargs):
    """Mark ``host`` as failed because this manager's method is disabled.

    Sets ``host["error"]`` to a translated "<method type> disabled" message.

    Args:
        host: Inventory host dict; mutated in place.
        **kwargs: Accepted for interface compatibility; not used here.
    """
    # Translate the template first and interpolate afterwards: looking up an
    # already formatted string can never match a catalog entry.
    host["error"] = _("{} disabled").format(self.__class__.method_type())
def host_callback(self, host, automation=None, **kwargs):
    """Finish one inventory host entry before it is written out.

    Certificate material is converted to files via ``convert_cert_to_file``
    (using ``kwargs["path_dir"]`` when given), and the method parameters
    from ``get_params`` are attached under ``host["params"]``.

    Args:
        host: Inventory host dict.
        automation: Platform automation object passed on to ``get_params``.
        **kwargs: Only ``path_dir`` is read.

    Returns:
        The host dict returned by ``convert_cert_to_file``, with ``params`` set.
    """
    method_type = self.__class__.method_type()
    path_dir = kwargs.get("path_dir")
    host = self.convert_cert_to_file(host, path_dir)
    host["params"] = self.get_params(automation, method_type)
    return host
@staticmethod
@ -282,16 +289,16 @@ class BasePlaybookManager(BaseManager):
@staticmethod
def generate_private_key_path(secret, path_dir):
    """Store a private key under ``path_dir`` and return the file path.

    The file name is a dot followed by the MD5 hex digest of the key text,
    so the same key always maps to the same file. An existing file is
    reused and not rewritten.

    Args:
        secret: Private key text.
        path_dir: Directory in which the key file is placed.

    Returns:
        Path of the key file.
    """
    key_name = "." + hashlib.md5(secret.encode("utf-8")).hexdigest()
    key_path = os.path.join(path_dir, key_name)
    if os.path.exists(key_path):
        return key_path

    # Create the file owner-only from the start, so the key is never
    # readable by other users between creation and the chmod below.
    fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(secret)
        # https://github.com/ansible/ansible-runner/issues/544
        # ssh requires OpenSSH format keys to have a full ending newline.
        # It does not require this for old-style PEM keys.
        if is_openssh_format_key(secret.encode("utf-8")):
            f.write("\n")
    os.chmod(key_path, 0o400)
    return key_path
@ -309,50 +316,87 @@ class BasePlaybookManager(BaseManager):
@staticmethod
def generate_playbook(method, sub_playbook_dir):
    """Copy a method's playbook directory and retarget every play at ``all``.

    ``method["dir"]`` is copied to ``<sub_playbook_dir>/project``. The
    ``main.yml`` inside is then rewritten so that each play has
    ``hosts: all``.

    Args:
        method: Method metadata dict; only ``dir`` is read.
        sub_playbook_dir: Directory that receives the ``project`` copy.

    Returns:
        Path of the rewritten ``main.yml``.
    """
    source_dir = method["dir"]
    project_dir = os.path.join(sub_playbook_dir, "project")
    main_yml = os.path.join(project_dir, "main.yml")
    shutil.copytree(source_dir, project_dir)

    with open(main_yml, "r") as src:
        plays = yaml.safe_load(src)
    for play in plays:
        play["hosts"] = "all"
    with open(main_yml, "w") as dst:
        yaml.safe_dump(plays, dst)

    return main_yml
def on_assets_not_ansible_enabled(self, assets):
    """Print each asset on its own tab-indented line.

    Called by ``check_automation_enabled`` when the platform has no
    automation or ansible is disabled.
    """
    for line in map("\t{}".format, assets):
        print(line)
def on_assets_not_method_enabled(self, assets, method_id):
    """Print each asset on its own tab-indented line.

    Called by ``check_automation_enabled`` when this manager's method is
    not enabled for the platform. ``method_id`` is accepted but not used.
    """
    for line in map("\t{}".format, assets):
        print(line)
def on_playbook_not_found(self, assets):
    """Hook for a batch of assets for which no playbook path was produced.

    Does nothing by default.
    """
    return None
def check_automation_enabled(self, platform, assets):
    """Tell whether this manager's automation method may run on ``platform``.

    Two checks are made in order:

    1. The platform has an automation object with ``ansible_enabled`` set;
       otherwise a notice is printed and ``on_assets_not_ansible_enabled``
       is called.
    2. The ``<method_type>_enabled`` flag is set and ``<method_type>_method``
       names an entry of ``self.method_id_meta_mapper``; otherwise
       ``on_assets_not_method_enabled`` is called.

    Args:
        platform: Platform object exposing ``automation`` and ``name``.
        assets: Assets of that platform, passed on to the hooks.

    Returns:
        bool: ``True`` only when both checks pass. Callers use the result
        as a predicate, so it must always be an explicit boolean.
    """
    automation = platform.automation
    if not automation or not automation.ansible_enabled:
        print(_(" - Platform {} ansible disabled").format(platform.name))
        self.on_assets_not_ansible_enabled(assets)
        return False

    method_type = self.__class__.method_type()
    method_id = getattr(automation, "{}_method".format(method_type))
    method_enabled = (
        getattr(automation, "{}_enabled".format(method_type))
        and method_id
        and method_id in self.method_id_meta_mapper
    )
    if not method_enabled:
        self.on_assets_not_method_enabled(assets, method_type)
        return False
    return True
def get_runners(self):
assets_group_by_platform = self.get_assets_group_by_platform()
if settings.DEBUG_DEV:
msg = 'Assets group by platform: {}'.format(dict(assets_group_by_platform))
msg = "Assets group by platform: {}".format(dict(assets_group_by_platform))
print(msg)
runners = []
for platform, assets in assets_group_by_platform.items():
if not assets:
continue
if not platform.automation or not platform.automation.ansible_enabled:
print(_(" - Platform {} ansible disabled").format(platform.name))
if not self.check_automation_enabled(platform, assets):
continue
assets_bulked = [assets[i:i + self.bulk_size] for i in range(0, len(assets), self.bulk_size)]
# 避免一个任务太大,分批执行
assets_bulked = [
assets[i : i + self.bulk_size]
for i in range(0, len(assets), self.bulk_size)
]
for i, _assets in enumerate(assets_bulked, start=1):
sub_dir = '{}_{}'.format(platform.name, i)
sub_dir = "{}_{}".format(platform.name, i)
playbook_dir = os.path.join(self.runtime_dir, sub_dir)
inventory_path = os.path.join(self.runtime_dir, sub_dir, 'hosts.json')
inventory_path = os.path.join(self.runtime_dir, sub_dir, "hosts.json")
method_id = getattr(platform.automation, '{}_method'.format(self.__class__.method_type()))
# method_id = getattr(
# platform.automation,
# "{}_method".format(self.__class__.method_type()),
# )
method = self.method_id_meta_mapper.get(method_id)
if not method:
logger.error("Method not found: {}".format(method_id))
continue
protocol = method.get('protocol')
protocol = method.get("protocol")
self.generate_inventory(_assets, inventory_path, protocol)
playbook_path = self.generate_playbook(method, playbook_dir)
if not playbook_path:
self.on_playbook_not_found(_assets)
continue
runer = SuperPlaybookRunner(
@ -362,9 +406,9 @@ class BasePlaybookManager(BaseManager):
callback=PlaybookCallback(),
)
with open(inventory_path, 'r') as f:
with open(inventory_path, "r") as f:
inventory_data = json.load(f)
if not inventory_data['all'].get('hosts'):
if not inventory_data["all"].get("hosts"):
continue
runners.append(runer)
@ -375,14 +419,14 @@ class BasePlaybookManager(BaseManager):
def on_host_error(self, host, error, result):
    """Print the host and its error when ``settings.DEBUG_DEV`` is truthy.

    ``result`` is accepted but not used here.
    """
    if not settings.DEBUG_DEV:
        return
    print("host error: {} -> {}".format(host, error))
def _on_host_success(self, host, result, hosts):
    """Forward the ``ok`` part of ``result`` to ``on_host_success``.

    An empty string is passed when ``result`` has no ``ok`` entry.
    ``hosts`` is accepted but not used here.
    """
    ok_result = result.get("ok", "")
    self.on_host_success(host, ok_result)
def _on_host_error(self, host, result, hosts):
    """Collect a host's error and failure detail, then call ``on_host_error``.

    The error is ``hosts[host]`` (empty string when absent). The detail is
    ``result["failures"]`` when truthy, otherwise ``result["dark"]``, with
    an empty string as the final fallback.
    """
    error = hosts.get(host, "")
    detail = result.get("failures", "")
    if not detail:
        detail = result.get("dark", "")
    self.on_host_error(host, error, detail)
def on_runner_success(self, runner, cb):
@ -390,9 +434,9 @@ class BasePlaybookManager(BaseManager):
for state, hosts in summary.items():
# 错误行为为host 是 dict ok 时是 list
if state == 'ok':
if state == "ok":
handler = self._on_host_success
elif state == 'skipped':
elif state == "skipped":
continue
else:
handler = self._on_host_error
@ -413,7 +457,11 @@ class BasePlaybookManager(BaseManager):
print(_(">>> Task preparation phase"), end="\n")
runners = self.get_runners()
if len(runners) > 1:
print(_(">>> Executing tasks in batches, total {runner_count}").format(runner_count=len(runners)))
print(
_(">>> Executing tasks in batches, total {runner_count}").format(
runner_count=len(runners)
)
)
elif len(runners) == 1:
print(_(">>> Start executing tasks"))
else:
@ -433,6 +481,4 @@ class BasePlaybookManager(BaseManager):
self.on_runner_failed(runner, e)
finally:
ssh_tunnel.local_gateway_clean(runner)
print('\n')
print("\n")

@ -74,6 +74,8 @@ def sorted_methods(methods):
# Directory of this module; handed to get_platform_automation_methods below.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
platform_automation_methods = get_platform_automation_methods(BASE_DIR)

if __name__ == '__main__':
    # Dump the method table only when run as a script. Printing at import
    # time would write it to stdout in every process that imports this module.
    print("platform_automation_methods: ")
    print(json.dumps(platform_automation_methods, indent=4))

Loading…
Cancel
Save