mirror of https://github.com/jumpserver/jumpserver
439 lines
15 KiB
Python
439 lines
15 KiB
Python
import hashlib
|
||
import json
|
||
import os
|
||
import shutil
|
||
import time
|
||
from collections import defaultdict
|
||
from socket import gethostname
|
||
|
||
import yaml
|
||
from django.conf import settings
|
||
from django.template.loader import render_to_string
|
||
from django.utils import timezone
|
||
from django.utils.translation import gettext as _
|
||
from premailer import transform
|
||
from sshtunnel import SSHTunnelForwarder
|
||
|
||
from assets.automations.methods import platform_automation_methods
|
||
from common.db.utils import safe_db_connection
|
||
from common.tasks import send_mail_async
|
||
from common.utils import get_logger, lazyproperty, is_openssh_format_key, ssh_pubkey_gen
|
||
from ops.ansible import JMSInventory, DefaultCallback, SuperPlaybookRunner
|
||
from ops.ansible.interface import interface
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
class SSHTunnelManager:
|
||
def __init__(self, *args, **kwargs):
|
||
self.gateway_servers = dict()
|
||
|
||
@staticmethod
|
||
def file_to_json(path):
|
||
with open(path, 'r') as f:
|
||
d = json.load(f)
|
||
return d
|
||
|
||
@staticmethod
|
||
def json_to_file(path, data):
|
||
with open(path, 'w') as f:
|
||
json.dump(data, f, indent=4, sort_keys=True)
|
||
|
||
def local_gateway_prepare(self, runner):
|
||
info = self.file_to_json(runner.inventory)
|
||
servers, not_valid = [], []
|
||
for k, host in info['all']['hosts'].items():
|
||
jms_asset, jms_gateway = host.get('jms_asset'), host.get('jms_gateway')
|
||
if not jms_gateway:
|
||
continue
|
||
try:
|
||
server = SSHTunnelForwarder(
|
||
(jms_gateway['address'], jms_gateway['port']),
|
||
ssh_username=jms_gateway['username'],
|
||
ssh_password=jms_gateway['secret'],
|
||
ssh_pkey=jms_gateway['private_key_path'],
|
||
remote_bind_address=(jms_asset['address'], jms_asset['port'])
|
||
)
|
||
server.start()
|
||
except Exception as e:
|
||
err_msg = 'Gateway is not active: %s' % jms_asset.get('name', '')
|
||
print(f'\033[31m {err_msg} 原因: {e} \033[0m\n')
|
||
not_valid.append(k)
|
||
else:
|
||
local_bind_port = server.local_bind_port
|
||
|
||
host['ansible_host'] = jms_asset['address'] = host[
|
||
'login_host'] = interface.get_gateway_proxy_host()
|
||
host['ansible_port'] = jms_asset['port'] = host['login_port'] = local_bind_port
|
||
servers.append(server)
|
||
|
||
# 网域不可连接的,就不继续执行此资源的后续任务了
|
||
for a in set(not_valid):
|
||
info['all']['hosts'].pop(a)
|
||
self.json_to_file(runner.inventory, info)
|
||
self.gateway_servers[runner.id] = servers
|
||
|
||
def local_gateway_clean(self, runner):
|
||
servers = self.gateway_servers.get(runner.id, [])
|
||
for s in servers:
|
||
try:
|
||
s.stop()
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
class PlaybookCallback(DefaultCallback):
|
||
def playbook_on_stats(self, event_data, **kwargs):
|
||
super().playbook_on_stats(event_data, **kwargs)
|
||
|
||
|
||
class BaseManager:
|
||
def __init__(self, execution):
|
||
self.execution = execution
|
||
self.time_start = time.time()
|
||
self.summary = defaultdict(int)
|
||
self.result = defaultdict(list)
|
||
self.duration = 0
|
||
|
||
def get_assets_group_by_platform(self):
|
||
return self.execution.all_assets_group_by_platform()
|
||
|
||
def pre_run(self):
|
||
self.execution.date_start = timezone.now()
|
||
self.execution.save(update_fields=['date_start'])
|
||
|
||
def update_execution(self):
|
||
self.duration = int(time.time() - self.time_start)
|
||
self.execution.date_finished = timezone.now()
|
||
self.execution.duration = self.duration
|
||
self.execution.summary = self.summary
|
||
self.execution.result = self.result
|
||
self.execution.status = 'success'
|
||
|
||
with safe_db_connection():
|
||
self.execution.save()
|
||
|
||
def print_summary(self):
|
||
pass
|
||
|
||
def get_report_template(self):
|
||
raise NotImplementedError
|
||
|
||
def get_report_subject(self):
|
||
return f'Automation {self.execution.id} finished'
|
||
|
||
def get_report_context(self):
|
||
return {
|
||
'execution': self.execution,
|
||
'summary': self.execution.summary,
|
||
'result': self.execution.result
|
||
}
|
||
|
||
def send_report_if_need(self):
|
||
recipients = self.execution.recipients
|
||
if not recipients:
|
||
return
|
||
|
||
report = self.gen_report()
|
||
report = transform(report)
|
||
subject = self.get_report_subject()
|
||
emails = [r.email for r in recipients if r.email]
|
||
send_mail_async(subject, report, emails, html_message=report)
|
||
|
||
def gen_report(self):
|
||
template_path = self.get_report_template()
|
||
context = self.get_report_context()
|
||
data = render_to_string(template_path, context)
|
||
return data
|
||
|
||
def post_run(self):
|
||
self.update_execution()
|
||
self.print_summary()
|
||
self.send_report_if_need()
|
||
|
||
def run(self, *args, **kwargs):
|
||
self.pre_run()
|
||
self.do_run(*args, **kwargs)
|
||
self.post_run()
|
||
|
||
def do_run(self, *args, **kwargs):
|
||
raise NotImplementedError
|
||
|
||
@staticmethod
|
||
def json_dumps(data):
|
||
return json.dumps(data, indent=4, sort_keys=True)
|
||
|
||
|
||
class BasePlaybookManager(BaseManager):
|
||
bulk_size = 100
|
||
ansible_account_policy = 'privileged_first'
|
||
ansible_account_prefer = 'root,Administrator'
|
||
|
||
def __init__(self, execution):
|
||
super().__init__(execution)
|
||
self.method_id_meta_mapper = {
|
||
method['id']: method
|
||
for method in self.platform_automation_methods
|
||
if method['method'] == self.__class__.method_type()
|
||
}
|
||
# 根据执行方式就行分组, 不同资产的改密、推送等操作可能会使用不同的执行方式
|
||
# 然后根据执行方式分组, 再根据 bulk_size 分组, 生成不同的 playbook
|
||
self.playbooks = []
|
||
params = self.execution.snapshot.get('params')
|
||
self.params = params or {}
|
||
|
||
def get_params(self, automation, method_type):
|
||
method_attr = '{}_method'.format(method_type)
|
||
method_params = '{}_params'.format(method_type)
|
||
method_id = getattr(automation, method_attr)
|
||
automation_params = getattr(automation, method_params)
|
||
serializer = self.method_id_meta_mapper[method_id]['params_serializer']
|
||
|
||
if serializer is None:
|
||
return {}
|
||
|
||
data = self.params.get(method_id)
|
||
if not data:
|
||
data = automation_params.get(method_id, {})
|
||
params = serializer(data).data
|
||
return params
|
||
|
||
@property
|
||
def platform_automation_methods(self):
|
||
return platform_automation_methods
|
||
|
||
@classmethod
|
||
def method_type(cls):
|
||
raise NotImplementedError
|
||
|
||
def get_assets_group_by_platform(self):
|
||
return self.execution.all_assets_group_by_platform()
|
||
|
||
def prepare_runtime_dir(self):
|
||
ansible_dir = settings.ANSIBLE_DIR
|
||
task_name = self.execution.snapshot['name']
|
||
dir_name = '{}_{}'.format(task_name.replace(' ', '_'), self.execution.id)
|
||
path = os.path.join(
|
||
ansible_dir, 'automations', self.execution.snapshot['type'],
|
||
dir_name, timezone.now().strftime('%Y%m%d_%H%M%S')
|
||
)
|
||
if not os.path.exists(path):
|
||
os.makedirs(path, exist_ok=True, mode=0o755)
|
||
return path
|
||
|
||
@lazyproperty
|
||
def runtime_dir(self):
|
||
path = self.prepare_runtime_dir()
|
||
if settings.DEBUG_DEV:
|
||
msg = 'Ansible runtime dir: {}'.format(path)
|
||
print(msg)
|
||
return path
|
||
|
||
@staticmethod
|
||
def write_cert_to_file(filename, content):
|
||
with open(filename, 'w') as f:
|
||
f.write(content)
|
||
return filename
|
||
|
||
def convert_cert_to_file(self, host, path_dir):
|
||
if not path_dir:
|
||
return host
|
||
|
||
specific = host.get('jms_asset', {}).get('secret_info', {})
|
||
cert_fields = ('ca_cert', 'client_key', 'client_cert')
|
||
filtered = list(filter(lambda x: specific.get(x), cert_fields))
|
||
if not filtered:
|
||
return host
|
||
|
||
cert_dir = os.path.join(path_dir, 'certs')
|
||
if not os.path.exists(cert_dir):
|
||
os.makedirs(cert_dir, 0o700, True)
|
||
|
||
for f in filtered:
|
||
result = self.write_cert_to_file(
|
||
os.path.join(cert_dir, f), specific.get(f)
|
||
)
|
||
host['jms_asset']['secret_info'][f] = result
|
||
return host
|
||
|
||
def host_callback(self, host, automation=None, **kwargs):
|
||
method_type = self.__class__.method_type()
|
||
enabled_attr = '{}_enabled'.format(method_type)
|
||
method_attr = '{}_method'.format(method_type)
|
||
|
||
method_enabled = (
|
||
automation
|
||
and getattr(automation, enabled_attr)
|
||
and getattr(automation, method_attr)
|
||
and getattr(automation, method_attr) in self.method_id_meta_mapper
|
||
)
|
||
|
||
if not method_enabled:
|
||
host['error'] = _('{} disabled'.format(self.__class__.method_type()))
|
||
return host
|
||
|
||
host = self.convert_cert_to_file(host, kwargs.get('path_dir'))
|
||
host['params'] = self.get_params(automation, method_type)
|
||
return host
|
||
|
||
@staticmethod
|
||
def generate_public_key(private_key):
|
||
return ssh_pubkey_gen(private_key=private_key, hostname=gethostname())
|
||
|
||
@staticmethod
|
||
def generate_private_key_path(secret, path_dir):
|
||
key_name = '.' + hashlib.md5(secret.encode('utf-8')).hexdigest()
|
||
key_path = os.path.join(path_dir, key_name)
|
||
|
||
if not os.path.exists(key_path):
|
||
# https://github.com/ansible/ansible-runner/issues/544
|
||
# ssh requires OpenSSH format keys to have a full ending newline.
|
||
# It does not require this for old-style PEM keys.
|
||
with open(key_path, 'w') as f:
|
||
f.write(secret)
|
||
if is_openssh_format_key(secret.encode('utf-8')):
|
||
f.write("\n")
|
||
os.chmod(key_path, 0o400)
|
||
return key_path
|
||
|
||
def generate_inventory(self, platformed_assets, inventory_path, protocol):
|
||
inventory = JMSInventory(
|
||
assets=platformed_assets,
|
||
account_prefer=self.ansible_account_prefer,
|
||
account_policy=self.ansible_account_policy,
|
||
host_callback=self.host_callback,
|
||
task_type=self.__class__.method_type(),
|
||
protocol=protocol,
|
||
)
|
||
inventory.write_to_file(inventory_path)
|
||
|
||
@staticmethod
|
||
def generate_playbook(method, sub_playbook_dir):
|
||
method_playbook_dir_path = method['dir']
|
||
sub_playbook_path = os.path.join(sub_playbook_dir, 'project', 'main.yml')
|
||
shutil.copytree(method_playbook_dir_path, os.path.dirname(sub_playbook_path))
|
||
|
||
with open(sub_playbook_path, 'r') as f:
|
||
plays = yaml.safe_load(f)
|
||
for play in plays:
|
||
play['hosts'] = 'all'
|
||
|
||
with open(sub_playbook_path, 'w') as f:
|
||
yaml.safe_dump(plays, f)
|
||
return sub_playbook_path
|
||
|
||
def get_runners(self):
|
||
assets_group_by_platform = self.get_assets_group_by_platform()
|
||
if settings.DEBUG_DEV:
|
||
msg = 'Assets group by platform: {}'.format(dict(assets_group_by_platform))
|
||
print(msg)
|
||
|
||
runners = []
|
||
for platform, assets in assets_group_by_platform.items():
|
||
if not assets:
|
||
continue
|
||
if not platform.automation or not platform.automation.ansible_enabled:
|
||
print(_(" - Platform {} ansible disabled").format(platform.name))
|
||
continue
|
||
|
||
assets_bulked = [assets[i:i + self.bulk_size] for i in range(0, len(assets), self.bulk_size)]
|
||
for i, _assets in enumerate(assets_bulked, start=1):
|
||
sub_dir = '{}_{}'.format(platform.name, i)
|
||
playbook_dir = os.path.join(self.runtime_dir, sub_dir)
|
||
inventory_path = os.path.join(self.runtime_dir, sub_dir, 'hosts.json')
|
||
|
||
method_id = getattr(platform.automation, '{}_method'.format(self.__class__.method_type()))
|
||
method = self.method_id_meta_mapper.get(method_id)
|
||
|
||
if not method:
|
||
logger.error("Method not found: {}".format(method_id))
|
||
continue
|
||
|
||
protocol = method.get('protocol')
|
||
self.generate_inventory(_assets, inventory_path, protocol)
|
||
playbook_path = self.generate_playbook(method, playbook_dir)
|
||
if not playbook_path:
|
||
continue
|
||
|
||
runer = SuperPlaybookRunner(
|
||
inventory_path,
|
||
playbook_path,
|
||
self.runtime_dir,
|
||
callback=PlaybookCallback(),
|
||
)
|
||
|
||
with open(inventory_path, 'r') as f:
|
||
inventory_data = json.load(f)
|
||
if not inventory_data['all'].get('hosts'):
|
||
continue
|
||
|
||
runners.append(runer)
|
||
return runners
|
||
|
||
def on_host_success(self, host, result):
|
||
pass
|
||
|
||
def on_host_error(self, host, error, result):
|
||
if settings.DEBUG_DEV:
|
||
print('host error: {} -> {}'.format(host, error))
|
||
|
||
def _on_host_success(self, host, result, hosts):
|
||
self.on_host_success(host, result.get("ok", ''))
|
||
|
||
def _on_host_error(self, host, result, hosts):
|
||
error = hosts.get(host, '')
|
||
detail = result.get('failures', '') or result.get('dark', '')
|
||
self.on_host_error(host, error, detail)
|
||
|
||
def on_runner_success(self, runner, cb):
|
||
summary = cb.summary
|
||
for state, hosts in summary.items():
|
||
# 错误行为为,host 是 dict, ok 时是 list
|
||
|
||
if state == 'ok':
|
||
handler = self._on_host_success
|
||
elif state == 'skipped':
|
||
continue
|
||
else:
|
||
handler = self._on_host_error
|
||
|
||
for host in hosts:
|
||
result = cb.host_results.get(host)
|
||
handler(host, result, hosts)
|
||
|
||
def on_runner_failed(self, runner, e):
|
||
print("Runner failed: {} {}".format(e, self))
|
||
|
||
def delete_runtime_dir(self):
|
||
if settings.DEBUG_DEV:
|
||
return
|
||
shutil.rmtree(self.runtime_dir, ignore_errors=True)
|
||
|
||
def do_run(self, *args, **kwargs):
|
||
print(_(">>> Task preparation phase"), end="\n")
|
||
runners = self.get_runners()
|
||
if len(runners) > 1:
|
||
print(_(">>> Executing tasks in batches, total {runner_count}").format(runner_count=len(runners)))
|
||
elif len(runners) == 1:
|
||
print(_(">>> Start executing tasks"))
|
||
else:
|
||
print(_(">>> No tasks need to be executed"), end="\n")
|
||
|
||
for i, runner in enumerate(runners, start=1):
|
||
if len(runners) > 1:
|
||
print(_(">>> Begin executing batch {index} of tasks").format(index=i))
|
||
ssh_tunnel = SSHTunnelManager()
|
||
ssh_tunnel.local_gateway_prepare(runner)
|
||
|
||
try:
|
||
kwargs.update({"clean_workspace": False})
|
||
cb = runner.run(**kwargs)
|
||
self.on_runner_success(runner, cb)
|
||
except Exception as e:
|
||
self.on_runner_failed(runner, e)
|
||
finally:
|
||
ssh_tunnel.local_gateway_clean(runner)
|
||
print('\n')
|
||
|
||
|