perf: update pam

pull/14517/head
ibuler 2024-11-18 19:06:04 +08:00
parent ca7d2130a5
commit 6f149b7c11
12 changed files with 404 additions and 242 deletions

View File

@ -1,6 +1,8 @@
from django.http import HttpResponse
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import status, mixins, viewsets from rest_framework import status, mixins, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from accounts.models import AutomationExecution from accounts.models import AutomationExecution
@ -98,7 +100,6 @@ class AutomationExecutionViewSet(
search_fields = ('trigger', 'automation__name') search_fields = ('trigger', 'automation__name')
filterset_fields = ('trigger', 'automation_id', 'automation__name') filterset_fields = ('trigger', 'automation_id', 'automation__name')
serializer_class = serializers.AutomationExecutionSerializer serializer_class = serializers.AutomationExecutionSerializer
tp: str tp: str
def get_queryset(self): def get_queryset(self):
@ -113,3 +114,10 @@ class AutomationExecutionViewSet(
pid=str(automation.pk), trigger=Trigger.manual, tp=self.tp pid=str(automation.pk), trigger=Trigger.manual, tp=self.tp
) )
return Response({'task': task.id}, status=status.HTTP_201_CREATED) return Response({'task': task.id}, status=status.HTTP_201_CREATED)
@action(methods=['get'], detail=True, url_path='report')
def report(self, request, *args, **kwargs):
execution = self.get_object()
report = execution.manager.gen_report()
return HttpResponse(report)

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models import Q, Count from django.db.models import Q, Count
from django.http import HttpResponse
from rest_framework.decorators import action from rest_framework.decorators import action
from accounts import serializers from accounts import serializers
@ -39,12 +38,6 @@ class CheckAccountExecutionViewSet(AutomationExecutionViewSet):
queryset = queryset.filter(automation__type=self.tp) queryset = queryset.filter(automation__type=self.tp)
return queryset return queryset
@action(methods=['get'], detail=True, url_path='report')
def report(self, request, *args, **kwargs):
execution = self.get_object()
report = execution.manager.gen_report()
return HttpResponse(report)
class AccountRiskViewSet(OrgBulkModelViewSet): class AccountRiskViewSet(OrgBulkModelViewSet):
model = AccountRisk model = AccountRisk

View File

@ -32,6 +32,7 @@ class GatherAccountsExecutionViewSet(AutomationExecutionViewSet):
("list", "accounts.view_gatheraccountsexecution"), ("list", "accounts.view_gatheraccountsexecution"),
("retrieve", "accounts.view_gatheraccountsexecution"), ("retrieve", "accounts.view_gatheraccountsexecution"),
("create", "accounts.add_gatheraccountsexecution"), ("create", "accounts.add_gatheraccountsexecution"),
("report", "accounts.view_gatheraccountsexecution"),
) )
tp = AutomationTypes.gather_accounts tp = AutomationTypes.gather_accounts

View File

@ -2,22 +2,15 @@
# #
import time import time
from django.utils import timezone
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from assets.automations.base.manager import BaseManager
from common.db.utils import safe_db_connection
from common.utils.timezone import local_now_display from common.utils.timezone import local_now_display
from .handlers import AccountBackupHandler from .handlers import AccountBackupHandler
class AccountBackupManager: class AccountBackupManager(BaseManager):
def __init__(self, execution):
self.execution = execution
self.date_start = timezone.now()
self.time_start = time.time()
self.date_end = None
self.time_end = None
self.timedelta = 0
def do_run(self): def do_run(self):
execution = self.execution execution = self.execution
account_backup_execution_being_executed = _('The account backup plan is being executed') account_backup_execution_being_executed = _('The account backup plan is being executed')
@ -25,24 +18,19 @@ class AccountBackupManager:
handler = AccountBackupHandler(execution) handler = AccountBackupHandler(execution)
handler.run() handler.run()
def pre_run(self): def send_report_if_need(self):
self.execution.date_start = self.date_start pass
self.execution.save()
def post_run(self): def update_execution(self):
self.time_end = time.time() timedelta = int(time.time() - self.time_start)
self.date_end = timezone.now() self.execution.timedelta = timedelta
with safe_db_connection():
self.execution.save(update_fields=['timedelta', ])
def print_summary(self):
print('\n\n' + '-' * 80) print('\n\n' + '-' * 80)
plan_execution_end = _('Plan execution end') plan_execution_end = _('Plan execution end')
print('{} {}\n'.format(plan_execution_end, local_now_display())) print('{} {}\n'.format(plan_execution_end, local_now_display()))
self.timedelta = self.time_end - self.time_start
time_cost = _('Time cost') time_cost = _('Time cost')
print('{}: {}s'.format(time_cost, self.timedelta)) print('{}: {}s'.format(time_cost, self.duration))
self.execution.timedelta = self.timedelta
self.execution.save()
def run(self):
self.pre_run()
self.do_run()
self.post_run()

View File

@ -1,14 +1,10 @@
import re import re
import time
from collections import defaultdict from collections import defaultdict
from django.template.loader import render_to_string
from django.utils import timezone from django.utils import timezone
from premailer import transform
from accounts.models import Account, AccountRisk from accounts.models import Account, AccountRisk
from common.db.utils import safe_db_connection from assets.automations.base.manager import BaseManager
from common.tasks import send_mail_async
from common.utils.strings import color_fmt from common.utils.strings import color_fmt
@ -91,32 +87,28 @@ def check_account_secrets(accounts, assets):
return summary, result return summary, result
class CheckAccountManager: class CheckAccountManager(BaseManager):
batch_size=100
def __init__(self, execution): def __init__(self, execution):
self.execution = execution super().__init__(execution)
self.date_start = timezone.now() self.accounts = []
self.time_start = time.time()
self.date_end = None
self.time_end = None
self.timedelta = 0
self.assets = [] self.assets = []
self.summary = {}
self.result = defaultdict(list)
def pre_run(self): def pre_run(self):
self.assets = self.execution.get_all_assets() self.assets = self.execution.get_all_assets()
self.execution.date_start = timezone.now() self.execution.date_start = timezone.now()
self.execution.save(update_fields=['date_start']) self.execution.save(update_fields=['date_start'])
def batch_run(self, batch_size=100): def do_run(self, *args, **kwargs):
for engine in self.execution.snapshot.get('engines', []): for engine in self.execution.snapshot.get('engines', []):
if engine == 'check_account_secret': if engine == 'check_account_secret':
handle = check_account_secrets handle = check_account_secrets
else: else:
continue continue
for i in range(0, len(self.assets), batch_size): for i in range(0, len(self.assets), self.batch_size):
_assets = self.assets[i:i + batch_size] _assets = self.assets[i:i + self.batch_size]
accounts = Account.objects.filter(asset__in=_assets) accounts = Account.objects.filter(asset__in=_assets)
summary, result = handle(accounts, _assets) summary, result = handle(accounts, _assets)
@ -125,51 +117,8 @@ class CheckAccountManager:
for k, v in result.items(): for k, v in result.items():
self.result[k].extend(v) self.result[k].extend(v)
def _update_execution_and_summery(self): def print_summary(self):
self.date_end = timezone.now()
self.time_end = time.time()
self.duration = self.time_end - self.time_start
self.execution.date_finished = timezone.now()
self.execution.status = 'success'
self.execution.summary = self.summary
self.execution.result = self.result
with safe_db_connection():
self.execution.save(update_fields=['date_finished', 'status', 'summary', 'result'])
def after_run(self):
self._update_execution_and_summery()
self._send_report()
tmpl = "\n---\nSummary: \nok: %s, weak password: %s, no secret: %s, using time: %ss" % ( tmpl = "\n---\nSummary: \nok: %s, weak password: %s, no secret: %s, using time: %ss" % (
self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.timedelta) self.summary['ok'], self.summary['weak_password'], self.summary['no_secret'], int(self.timedelta)
) )
print(tmpl) print(tmpl)
def gen_report(self):
template_path = 'accounts/check_account_report.html'
context = {
'execution': self.execution,
'summary': self.execution.summary,
'result': self.execution.result
}
data = render_to_string(template_path, context)
return data
def _send_report(self):
recipients = self.execution.recipients
if not recipients:
return
report = self.gen_report()
report = transform(report)
print("Send resport to: {}".format([str(r) for r in recipients]))
subject = f'Check account automation {self.execution.id} finished'
emails = [r.email for r in recipients if r.email]
send_mail_async(subject, report, emails, html_message=report)
def run(self,):
self.pre_run()
self.batch_run()
self.after_run()

View File

@ -16,10 +16,34 @@ from ...notifications import GatherAccountChangeMsg
logger = get_logger(__name__) logger = get_logger(__name__)
class GatherAccountsManager(AccountBasePlaybookManager): diff_items = [
diff_items = [ 'authorized_keys', 'sudoers', 'groups',
'authorized_keys', 'sudoers', 'groups', ]
]
def get_items_diff(ori_account, d):
if hasattr(ori_account, '_diff'):
return ori_account._diff
diff = {}
for item in diff_items:
ori = getattr(ori_account, item)
new = d.get(item, '')
if not ori:
continue
if isinstance(new, timezone.datetime):
new = ori.strftime('%Y-%m-%d %H:%M:%S')
ori = ori.strftime('%Y-%m-%d %H:%M:%S')
if new != ori:
diff[item] = get_text_diff(ori, new)
ori_account._diff = diff
return diff
class AnalyseAccountRisk:
long_time = timezone.timedelta(days=90) long_time = timezone.timedelta(days=90)
datetime_check_items = [ datetime_check_items = [
{'field': 'date_last_login', 'risk': 'zombie', 'delta': long_time}, {'field': 'date_last_login', 'risk': 'zombie', 'delta': long_time},
@ -27,20 +51,101 @@ class GatherAccountsManager(AccountBasePlaybookManager):
{'field': 'date_password_expired', 'risk': 'password_expired', 'delta': timezone.timedelta(seconds=1)} {'field': 'date_password_expired', 'risk': 'password_expired', 'delta': timezone.timedelta(seconds=1)}
] ]
def __init__(self, check_risk=True):
self.check_risk = check_risk
self.now = timezone.now()
self.pending_add_risks = []
def _analyse_item_changed(self, ori_account, d):
diff = get_items_diff(ori_account, d)
if not diff:
return
for k, v in diff.items():
self.pending_add_risks.append(dict(
asset=ori_account.asset, username=ori_account.username,
risk=k+'_changed', detail={'diff': v}
))
def perform_save_risks(self, risks):
# 提前取出来,避免每次都查数据库
assets = {r['asset'] for r in risks}
assets_risks = AccountRisk.objects.filter(asset__in=assets)
assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks}
for d in risks:
detail = d.pop('detail', {})
detail['datetime'] = self.now.isoformat()
key = f"{d['asset'].id}_{d['username']}_{d['risk']}"
found = assets_risks.get(key)
if not found:
r = AccountRisk(**d, details=[detail])
r.save()
continue
found.details.append(detail)
found.save(update_fields=['details'])
def _analyse_datetime_changed(self, ori_account, d, asset, username):
basic = {'asset': asset, 'username': username}
for item in self.datetime_check_items:
field = item['field']
risk = item['risk']
delta = item['delta']
date = d.get(field)
if not date:
continue
pre_date = ori_account and getattr(ori_account, field)
if pre_date == date:
continue
if date and date < timezone.now() - delta:
self.pending_add_risks.append(
dict(**basic, risk=risk, detail={'date': date.isoformat()})
)
def batch_analyse_risk(self, asset, ori_account, d, batch_size=20):
if not self.check_risk:
return
if asset is None:
if self.pending_add_risks:
self.perform_save_risks(self.pending_add_risks)
self.pending_add_risks = []
return
basic = {'asset': asset, 'username': d['username']}
if ori_account:
self._analyse_item_changed(ori_account, d)
else:
self.pending_add_risks.append(
dict(**basic, risk='ghost')
)
self._analyse_datetime_changed(ori_account, d, asset, d['username'])
if len(self.pending_add_risks) > batch_size:
self.batch_analyse_risk(None, None, {})
class GatherAccountsManager(AccountBasePlaybookManager):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.host_asset_mapper = {} self.host_asset_mapper = {}
self.asset_account_info = {} self.asset_account_info = {}
self.pending_add_accounts = []
self.pending_update_accounts = []
self.asset_usernames_mapper = defaultdict(set) self.asset_usernames_mapper = defaultdict(set)
self.ori_asset_usernames = defaultdict(set) self.ori_asset_usernames = defaultdict(set)
self.ori_gathered_usernames = defaultdict(set) self.ori_gathered_usernames = defaultdict(set)
self.ori_gathered_accounts_mapper = dict() self.ori_gathered_accounts_mapper = dict()
self.is_sync_account = self.execution.snapshot.get('is_sync_account') self.is_sync_account = self.execution.snapshot.get('is_sync_account')
self.pending_add_accounts = [] self.check_risk = self.execution.snapshot.get('check_risk', False)
self.pending_update_accounts = []
self.pending_add_risks = []
self.now = timezone.now()
@classmethod @classmethod
def method_type(cls): def method_type(cls):
@ -168,109 +273,14 @@ class GatherAccountsManager(AccountBasePlaybookManager):
if len(self.pending_add_accounts) > batch_size: if len(self.pending_add_accounts) > batch_size:
self.batch_create_gathered_account(None) self.batch_create_gathered_account(None)
def _analyse_item_changed(self, ori_account, d):
diff = self.get_items_diff(ori_account, d)
if not diff:
return
for k, v in diff.items():
self.pending_add_risks.append(dict(
asset=ori_account.asset, username=ori_account.username,
risk=k+'_changed', detail={'diff': v}
))
def perform_save_risks(self, risks):
# 提前取出来,避免每次都查数据库
assets = {r['asset'] for r in risks}
assets_risks = AccountRisk.objects.filter(asset__in=assets)
assets_risks = {f"{r.asset_id}_{r.username}_{r.risk}": r for r in assets_risks}
for d in risks:
detail = d.pop('detail', {})
detail['datetime'] = self.now.isoformat()
key = f"{d['asset'].id}_{d['username']}_{d['risk']}"
found = assets_risks.get(key)
if not found:
r = AccountRisk(**d, details=[detail])
r.save()
continue
found.details.append(detail)
found.save(update_fields=['details'])
def _analyse_datetime_changed(self, ori_account, d, asset, username):
basic = {'asset': asset, 'username': username}
for item in self.datetime_check_items:
field = item['field']
risk = item['risk']
delta = item['delta']
date = d.get(field)
if not date:
continue
pre_date = ori_account and getattr(ori_account, field)
if pre_date == date:
continue
if date and date < timezone.now() - delta:
self.pending_add_risks.append(
dict(**basic, risk=risk, detail={'date': date.isoformat()})
)
def batch_analyse_risk(self, asset, ori_account, d, batch_size=20):
if asset is None:
if self.pending_add_risks:
self.perform_save_risks(self.pending_add_risks)
self.pending_add_risks = []
return
basic = {'asset': asset, 'username': d['username']}
if ori_account:
self._analyse_item_changed(ori_account, d)
else:
self.pending_add_risks.append(
dict(**basic, risk='ghost')
)
self._analyse_datetime_changed(ori_account, d, asset, d['username'])
if len(self.pending_add_risks) > batch_size:
self.batch_analyse_risk(None, None, {})
def get_items_diff(self, ori_account, d):
if hasattr(ori_account, '_diff'):
return ori_account._diff
diff = {}
for item in self.diff_items:
ori = getattr(ori_account, item)
new = d.get(item, '')
if not ori:
continue
if isinstance(new, timezone.datetime):
new = ori.strftime('%Y-%m-%d %H:%M:%S')
ori = ori.strftime('%Y-%m-%d %H:%M:%S')
if new != ori:
diff[item] = get_text_diff(ori, new)
ori_account._diff = diff
return diff
def batch_update_gathered_account(self, ori_account, d, batch_size=20): def batch_update_gathered_account(self, ori_account, d, batch_size=20):
if not ori_account or d is None: if not ori_account or d is None:
if self.pending_update_accounts: if self.pending_update_accounts:
GatheredAccount.objects.bulk_update(self.pending_update_accounts, [*self.diff_items]) GatheredAccount.objects.bulk_update(self.pending_update_accounts, [*diff_items])
self.pending_update_accounts = [] self.pending_update_accounts = []
return return
diff = self.get_items_diff(ori_account, d) diff = get_items_diff(ori_account, d)
if diff: if diff:
for k in diff: for k in diff:
setattr(ori_account, k, d[k]) setattr(ori_account, k, d[k])
@ -279,7 +289,9 @@ class GatherAccountsManager(AccountBasePlaybookManager):
if len(self.pending_update_accounts) > batch_size: if len(self.pending_update_accounts) > batch_size:
self.batch_update_gathered_account(None, None) self.batch_update_gathered_account(None, None)
def update_or_create_accounts(self): def do_run(self):
risk_analyser = AnalyseAccountRisk(self.check_risk)
for asset, accounts_data in self.asset_account_info.items(): for asset, accounts_data in self.asset_account_info.items():
with (tmp_to_org(asset.org_id)): with (tmp_to_org(asset.org_id)):
gathered_accounts = [] gathered_accounts = []
@ -291,21 +303,18 @@ class GatherAccountsManager(AccountBasePlaybookManager):
self.batch_create_gathered_account(d) self.batch_create_gathered_account(d)
else: else:
self.batch_update_gathered_account(ori_account, d) self.batch_update_gathered_account(ori_account, d)
risk_analyser.batch_analyse_risk(asset, ori_account, d)
self.batch_analyse_risk(asset, ori_account, d)
self.update_gather_accounts_status(asset) self.update_gather_accounts_status(asset)
GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account) GatheredAccount.sync_accounts(gathered_accounts, self.is_sync_account)
self.batch_create_gathered_account(None) self.batch_create_gathered_account(None)
self.batch_update_gathered_account(None, None) self.batch_update_gathered_account(None, None)
self.batch_analyse_risk(None, None, {}) risk_analyser.batch_analyse_risk(None, None, {})
def run(self, *args, **kwargs): def before_run(self):
super().run(*args, **kwargs) super().before_run()
self.prefetch_origin_account_usernames() self.prefetch_origin_account_usernames()
self.update_or_create_accounts()
# self.send_email_if_need()
def generate_send_users_and_change_info(self): def generate_send_users_and_change_info(self):
recipients = self.execution.recipients recipients = self.execution.recipients

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-11-18 03:32
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("accounts", "0013_checkaccountautomation_recipients"),
]
operations = [
migrations.AddField(
model_name="gatheraccountsautomation",
name="check_risk",
field=models.BooleanField(default=True, verbose_name="Check risk"),
),
]

View File

@ -17,9 +17,6 @@ class CheckAccountAutomation(AccountBaseAutomation):
engines = models.ManyToManyField('CheckAccountEngine', related_name='check_automations', verbose_name=_('Engines')) engines = models.ManyToManyField('CheckAccountEngine', related_name='check_automations', verbose_name=_('Engines'))
recipients = models.ManyToManyField('users.User', verbose_name=_("Recipient"), blank=True) recipients = models.ManyToManyField('users.User', verbose_name=_("Recipient"), blank=True)
def get_report_template(self):
return 'accounts/check_account_report.html'
def to_attr_json(self): def to_attr_json(self):
attr_json = super().to_attr_json() attr_json = super().to_attr_json()
attr_json.update({ attr_json.update({

View File

@ -91,11 +91,13 @@ class GatherAccountsAutomation(AccountBaseAutomation):
default=False, blank=True, verbose_name=_("Is sync account") default=False, blank=True, verbose_name=_("Is sync account")
) )
recipients = models.ManyToManyField('users.User', verbose_name=_("Recipient"), blank=True) recipients = models.ManyToManyField('users.User', verbose_name=_("Recipient"), blank=True)
check_risk = models.BooleanField(default=True, verbose_name=_("Check risk"))
def to_attr_json(self): def to_attr_json(self):
attr_json = super().to_attr_json() attr_json = super().to_attr_json()
attr_json.update({ attr_json.update({
'is_sync_account': self.is_sync_account, 'is_sync_account': self.is_sync_account,
'check_risk': self.check_risk,
'recipients': [ 'recipients': [
str(recipient.id) for recipient in self.recipients.all() str(recipient.id) for recipient in self.recipients.all()
] ]

View File

@ -19,9 +19,15 @@ class GatherAccountAutomationSerializer(BaseAutomationSerializer):
class Meta: class Meta:
model = GatherAccountsAutomation model = GatherAccountsAutomation
read_only_fields = BaseAutomationSerializer.Meta.read_only_fields read_only_fields = BaseAutomationSerializer.Meta.read_only_fields
fields = BaseAutomationSerializer.Meta.fields \ fields = (BaseAutomationSerializer.Meta.fields
+ ['is_sync_account', 'recipients'] + read_only_fields + ['is_sync_account', 'check_risk', 'recipients']
extra_kwargs = BaseAutomationSerializer.Meta.extra_kwargs + read_only_fields)
extra_kwargs = {
'check_risk': {
'help_text': _('Whether to check the risk of the gathered accounts.'),
},
**BaseAutomationSerializer.Meta.extra_kwargs
}
@property @property
def model_type(self): def model_type(self):

View File

@ -0,0 +1,108 @@
{% load i18n %}
<div class='summary'>
<p>{% trans 'The following is a summary of the account check tasks. Please review and handle them' %}</p>
<table>
<thead>
<tr>
<th colspan='2'>任务汇总: </th>
</tr>
</thead>
<tbody>
<tr>
<td>{% trans 'Task name' %}: </td>
<td>{{ execution.automation.name }} </td>
</tr>
<tr>
<td>{% trans 'Date start' %}: </td>
<td>{{ execution.date_start }}</td>
</tr>
<tr>
<td>{% trans 'Date end' %}: </td>
<td>{{ execution.date_finished }}</td>
</tr>
<tr>
<td>{% trans 'Time using' %}: </td>
<td>{{ execution.duration }}s</td>
</tr>
<tr>
<td>{% trans 'Assets count' %}: </td>
<td>{{ summary.assets }}</td>
</tr>
<tr>
<td>{% trans 'Account count' %}: </td>
<td>{{ summary.accounts }}</td>
</tr>
<tr>
<td>{% trans 'Week password count' %}:</td>
<td> <span> {{ summary.weak_password }}</span></td>
</tr>
<tr>
<td>{% trans 'Ok count' %}: </td>
<td>{{ summary.ok }}</td>
</tr>
<tr>
<td>{% trans 'No password count' %}: </td>
<td>{{ summary.no_secret }}</td>
</tr>
</tbody>
</table>
</div>
<div class='result'>
<p>{% trans 'Account check details' %}:</p>
<table style="">
<thead>
<tr>
<th>{% trans 'No.' %}</th>
<th>{% trans 'Asset' %}</th>
<th>{% trans 'Username' %}</th>
<th>{% trans 'Result' %}</th>
</tr>
</thead>
<tbody>
{% for account in result.weak_password %}
<tr>
<td>{{ forloop.counter }}</td>
<td>{{ account.asset }}</td>
<td>{{ account.username }}</td>
<td>{% trans 'Week password' %}</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<style>
table {
width: 100%;
border-collapse: collapse;
max-width: 100%;
text-align: left;
margin-top: 20px;
padding: 20px;
}
th {
background: #f2f2f2;
font-size: 14px;
padding: 5px;
border: 1px solid #ddd;
}
tr :first-child {
width: 30%;
}
td {
border: 1px solid #ddd;
padding: 5px;
font-size: 12px;
}
.result tr :first-child {
width: 10%;
}
</style>

View File

@ -2,15 +2,21 @@ import hashlib
import json import json
import os import os
import shutil import shutil
import time
from collections import defaultdict
from socket import gethostname from socket import gethostname
import yaml import yaml
from django.conf import settings from django.conf import settings
from django.template.loader import render_to_string
from django.utils import timezone from django.utils import timezone
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from premailer import transform
from sshtunnel import SSHTunnelForwarder from sshtunnel import SSHTunnelForwarder
from assets.automations.methods import platform_automation_methods from assets.automations.methods import platform_automation_methods
from common.db.utils import safe_db_connection
from common.tasks import send_mail_async
from common.utils import get_logger, lazyproperty, is_openssh_format_key, ssh_pubkey_gen from common.utils import get_logger, lazyproperty, is_openssh_format_key, ssh_pubkey_gen
from ops.ansible import JMSInventory, DefaultCallback, SuperPlaybookRunner from ops.ansible import JMSInventory, DefaultCallback, SuperPlaybookRunner
from ops.ansible.interface import interface from ops.ansible.interface import interface
@ -81,13 +87,87 @@ class PlaybookCallback(DefaultCallback):
super().playbook_on_stats(event_data, **kwargs) super().playbook_on_stats(event_data, **kwargs)
class BasePlaybookManager: class BaseManager:
def __init__(self, execution):
self.execution = execution
self.time_start = time.time()
self.summary = defaultdict(int)
self.result = defaultdict(list)
self.duration = 0
def get_assets_group_by_platform(self):
return self.execution.all_assets_group_by_platform()
def before_run(self):
self.execution.date_start = timezone.now()
self.execution.save(update_fields=['date_start'])
def get_report_subject(self):
return f'Automation {self.execution.id} finished'
def send_report_if_need(self):
recipients = self.execution.recipients
if not recipients:
return
report = self.gen_report()
report = transform(report)
subject = self.get_report_subject()
print("Send resport to: {}".format([str(r) for r in recipients]))
emails = [r.email for r in recipients if r.email]
send_mail_async(subject, report, emails, html_message=report)
def update_execution(self):
self.duration = int(time.time() - self.time_start)
self.execution.date_finished = timezone.now()
self.execution.duration = self.duration
self.execution.summary = self.summary
self.execution.result = self.result
with safe_db_connection():
self.execution.save(update_fields=['date_finished', 'duration', 'summary', 'result'])
def print_summary(self):
pass
def get_template_path(self):
raise NotImplementedError
def gen_report(self):
template_path = self.get_template_path()
context = {
'execution': self.execution,
'summary': self.execution.summary,
'result': self.execution.result
}
data = render_to_string(template_path, context)
return data
def after_run(self):
self.update_execution()
self.print_summary()
self.send_report_if_need()
def run(self, *args, **kwargs):
self.before_run()
self.do_run(*args, **kwargs)
self.after_run()
def do_run(self, *args, **kwargs):
raise NotImplementedError
@staticmethod
def json_dumps(data):
return json.dumps(data, indent=4, sort_keys=True)
class BasePlaybookManager(BaseManager):
bulk_size = 100 bulk_size = 100
ansible_account_policy = 'privileged_first' ansible_account_policy = 'privileged_first'
ansible_account_prefer = 'root,Administrator' ansible_account_prefer = 'root,Administrator'
def __init__(self, execution): def __init__(self, execution):
self.execution = execution super().__init__(execution)
self.method_id_meta_mapper = { self.method_id_meta_mapper = {
method['id']: method method['id']: method
for method in self.platform_automation_methods for method in self.platform_automation_methods
@ -178,10 +258,12 @@ class BasePlaybookManager:
enabled_attr = '{}_enabled'.format(method_type) enabled_attr = '{}_enabled'.format(method_type)
method_attr = '{}_method'.format(method_type) method_attr = '{}_method'.format(method_type)
method_enabled = automation and \ method_enabled = (
getattr(automation, enabled_attr) and \ automation
getattr(automation, method_attr) and \ and getattr(automation, enabled_attr)
getattr(automation, method_attr) in self.method_id_meta_mapper and getattr(automation, method_attr)
and getattr(automation, method_attr) in self.method_id_meta_mapper
)
if not method_enabled: if not method_enabled:
host['error'] = _('{} disabled'.format(self.__class__.method_type())) host['error'] = _('{} disabled'.format(self.__class__.method_type()))
@ -242,6 +324,7 @@ class BasePlaybookManager:
if settings.DEBUG_DEV: if settings.DEBUG_DEV:
msg = 'Assets group by platform: {}'.format(dict(assets_group_by_platform)) msg = 'Assets group by platform: {}'.format(dict(assets_group_by_platform))
print(msg) print(msg)
runners = [] runners = []
for platform, assets in assets_group_by_platform.items(): for platform, assets in assets_group_by_platform.items():
if not assets: if not assets:
@ -249,8 +332,8 @@ class BasePlaybookManager:
if not platform.automation or not platform.automation.ansible_enabled: if not platform.automation or not platform.automation.ansible_enabled:
print(_(" - Platform {} ansible disabled").format(platform.name)) print(_(" - Platform {} ansible disabled").format(platform.name))
continue continue
assets_bulked = [assets[i:i + self.bulk_size] for i in range(0, len(assets), self.bulk_size)]
assets_bulked = [assets[i:i + self.bulk_size] for i in range(0, len(assets), self.bulk_size)]
for i, _assets in enumerate(assets_bulked, start=1): for i, _assets in enumerate(assets_bulked, start=1):
sub_dir = '{}_{}'.format(platform.name, i) sub_dir = '{}_{}'.format(platform.name, i)
playbook_dir = os.path.join(self.runtime_dir, sub_dir) playbook_dir = os.path.join(self.runtime_dir, sub_dir)
@ -262,6 +345,7 @@ class BasePlaybookManager:
if not method: if not method:
logger.error("Method not found: {}".format(method_id)) logger.error("Method not found: {}".format(method_id))
continue continue
protocol = method.get('protocol') protocol = method.get('protocol')
self.generate_inventory(_assets, inventory_path, protocol) self.generate_inventory(_assets, inventory_path, protocol)
playbook_path = self.generate_playbook(method, playbook_dir) playbook_path = self.generate_playbook(method, playbook_dir)
@ -290,36 +374,37 @@ class BasePlaybookManager:
if settings.DEBUG_DEV: if settings.DEBUG_DEV:
print('host error: {} -> {}'.format(host, error)) print('host error: {} -> {}'.format(host, error))
def _on_host_success(self, host, result, error, detail):
self.on_host_success(host, result)
def _on_host_error(self, host, result, error, detail):
self.on_host_error(host, error, detail)
def on_runner_success(self, runner, cb): def on_runner_success(self, runner, cb):
summary = cb.summary summary = cb.summary
for state, hosts in summary.items(): for state, hosts in summary.items():
if state == 'ok':
handler = self._on_host_success
elif state == 'skipped':
continue
else:
handler = self._on_host_error
for host in hosts: for host in hosts:
result = cb.host_results.get(host) result = cb.host_results.get(host)
if state == 'ok': error = hosts.get(host, '')
self.on_host_success(host, result.get('ok', '')) detail = result.get('failures', '') or result.get('dark', '')
elif state == 'skipped': handler(host, result, error, detail)
pass
else:
error = hosts.get(host)
self.on_host_error(
host, error,
result.get('failures', '')
or result.get('dark', '')
)
def on_runner_failed(self, runner, e): def on_runner_failed(self, runner, e):
print("Runner failed: {} {}".format(e, self)) print("Runner failed: {} {}".format(e, self))
@staticmethod
def json_dumps(data):
return json.dumps(data, indent=4, sort_keys=True)
def delete_runtime_dir(self): def delete_runtime_dir(self):
if settings.DEBUG_DEV: if settings.DEBUG_DEV:
return return
shutil.rmtree(self.runtime_dir, ignore_errors=True) shutil.rmtree(self.runtime_dir, ignore_errors=True)
def run(self, *args, **kwargs): def do_run(self, *args, **kwargs):
print(_(">>> Task preparation phase"), end="\n") print(_(">>> Task preparation phase"), end="\n")
runners = self.get_runners() runners = self.get_runners()
if len(runners) > 1: if len(runners) > 1:
@ -329,12 +414,12 @@ class BasePlaybookManager:
else: else:
print(_(">>> No tasks need to be executed"), end="\n") print(_(">>> No tasks need to be executed"), end="\n")
self.execution.date_start = timezone.now()
for i, runner in enumerate(runners, start=1): for i, runner in enumerate(runners, start=1):
if len(runners) > 1: if len(runners) > 1:
print(_(">>> Begin executing batch {index} of tasks").format(index=i)) print(_(">>> Begin executing batch {index} of tasks").format(index=i))
ssh_tunnel = SSHTunnelManager() ssh_tunnel = SSHTunnelManager()
ssh_tunnel.local_gateway_prepare(runner) ssh_tunnel.local_gateway_prepare(runner)
try: try:
kwargs.update({"clean_workspace": False}) kwargs.update({"clean_workspace": False})
cb = runner.run(**kwargs) cb = runner.run(**kwargs)
@ -344,7 +429,5 @@ class BasePlaybookManager:
finally: finally:
ssh_tunnel.local_gateway_clean(runner) ssh_tunnel.local_gateway_clean(runner)
print('\n') print('\n')
self.execution.status = 'success'
self.execution.date_finished = timezone.now()
self.execution.save()
self.delete_runtime_dir()