perf: update change status

pull/14578/head
ibuler 2024-12-03 17:49:04 +08:00
parent 528c333a6d
commit fc3fc40341
13 changed files with 300 additions and 125 deletions

View File

@ -9,14 +9,22 @@ from rest_framework.response import Response
from accounts import serializers from accounts import serializers
from accounts.const import AutomationTypes from accounts.const import AutomationTypes
from accounts.models import CheckAccountAutomation, AccountRisk, RiskChoice, CheckAccountEngine from accounts.models import (
CheckAccountAutomation,
AccountRisk,
RiskChoice,
CheckAccountEngine,
)
from common.api import JMSModelViewSet from common.api import JMSModelViewSet
from common.utils import many_get
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from .base import AutomationExecutionViewSet from .base import AutomationExecutionViewSet
__all__ = [ __all__ = [
'CheckAccountAutomationViewSet', 'CheckAccountExecutionViewSet', "CheckAccountAutomationViewSet",
'AccountRiskViewSet', 'CheckAccountEngineViewSet', "CheckAccountExecutionViewSet",
"AccountRiskViewSet",
"CheckAccountEngineViewSet",
] ]
from ...risk_handlers import RiskHandler from ...risk_handlers import RiskHandler
@ -24,7 +32,7 @@ from ...risk_handlers import RiskHandler
class CheckAccountAutomationViewSet(OrgBulkModelViewSet): class CheckAccountAutomationViewSet(OrgBulkModelViewSet):
model = CheckAccountAutomation model = CheckAccountAutomation
filterset_fields = ('name',) filterset_fields = ("name",)
search_fields = filterset_fields search_fields = filterset_fields
serializer_class = serializers.CheckAccountAutomationSerializer serializer_class = serializers.CheckAccountAutomationSerializer
@ -36,7 +44,7 @@ class CheckAccountExecutionViewSet(AutomationExecutionViewSet):
("create", "accounts.add_checkaccountexecution"), ("create", "accounts.add_checkaccountexecution"),
("report", "accounts.view_checkaccountsexecution"), ("report", "accounts.view_checkaccountsexecution"),
) )
ordering = ('-date_created',) ordering = ("-date_created",)
tp = AutomationTypes.check_account tp = AutomationTypes.check_account
def get_queryset(self): def get_queryset(self):
@ -47,61 +55,86 @@ class CheckAccountExecutionViewSet(AutomationExecutionViewSet):
class AccountRiskViewSet(OrgBulkModelViewSet): class AccountRiskViewSet(OrgBulkModelViewSet):
model = AccountRisk model = AccountRisk
search_fields = ('username', 'asset') search_fields = ("username", "asset")
filterset_fields = ('risk', 'status', 'asset') filterset_fields = ("risk", "status", "asset")
serializer_classes = { serializer_classes = {
'default': serializers.AccountRiskSerializer, "default": serializers.AccountRiskSerializer,
'assets': serializers.AssetRiskSerializer, "assets": serializers.AssetRiskSerializer,
'handle': serializers.HandleRiskSerializer "handle": serializers.HandleRiskSerializer,
} }
ordering_fields = ( ordering_fields = ("asset", "risk", "status", "username", "date_created")
'asset', 'risk', 'status', 'username', 'date_created' ordering = ("status", "asset", "date_created")
)
ordering = ('-asset', 'date_created')
rbac_perms = { rbac_perms = {
'sync_accounts': 'assets.add_accountrisk', "sync_accounts": "assets.add_accountrisk",
'assets': 'accounts.view_accountrisk', "assets": "accounts.view_accountrisk",
'handle': 'accounts.change_accountrisk' "handle": "accounts.change_accountrisk",
} }
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
raise MethodNotAllowed('PUT') raise MethodNotAllowed("PUT")
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
raise MethodNotAllowed('POST') raise MethodNotAllowed("POST")
@action(methods=['get'], detail=False, url_path='assets') @action(methods=["get"], detail=False, url_path="assets")
def assets(self, request, *args, **kwargs): def assets(self, request, *args, **kwargs):
annotations = { annotations = {
f'{risk[0]}_count': Count('id', filter=Q(risk=risk[0])) f"{risk[0]}_count": Count("id", filter=Q(risk=risk[0]))
for risk in RiskChoice.choices for risk in RiskChoice.choices
} }
queryset = ( queryset = (
AccountRisk.objects AccountRisk.objects.select_related(
.select_related('asset', 'asset__platform') # 使用 select_related 来优化 asset 和 asset__platform 的查询 "asset", "asset__platform"
.values('asset__id', 'asset__name', 'asset__address', 'asset__platform__name') # 添加需要的字段 ) # 使用 select_related 来优化 asset 和 asset__platform 的查询
.annotate(risk_total=Count('id')) # 计算风险总数 .values(
"asset__id", "asset__name", "asset__address", "asset__platform__name"
) # 添加需要的字段
.annotate(risk_total=Count("id")) # 计算风险总数
.annotate(**annotations) # 使用上面定义的 annotations 进行计数 .annotate(**annotations) # 使用上面定义的 annotations 进行计数
) )
return self.get_paginated_response_from_queryset(queryset) return self.get_paginated_response_from_queryset(queryset)
@action(methods=['post'], detail=False, url_path='handle') @action(methods=["post"], detail=False, url_path="handle")
def handle(self, request, *args, **kwargs): def handle(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) s = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) s.is_valid(raise_exception=True)
asset, username, act, risk = itemgetter('asset', 'username', 'action', 'risk')(serializer.validated_data) asset, username, act, risk = many_get(s.validated_data, ("asset", "username", "action", "risk"))
handler = RiskHandler(asset=asset, username=username) handler = RiskHandler(asset=asset, username=username, request=self.request)
data = handler.handle(act, risk) data = handler.handle(act, risk)
if not data: if not data:
data = {'message': 'Success'} data = {"message": "Success"}
return Response(data) return Response(data)
class CheckAccountEngineViewSet(JMSModelViewSet): class CheckAccountEngineViewSet(JMSModelViewSet):
search_fields = ('name',) search_fields = ("name",)
serializer_class = serializers.CheckAccountEngineSerializer serializer_class = serializers.CheckAccountEngineSerializer
def get_queryset(self): @staticmethod
return CheckAccountEngine.objects.all() def init_if_need():
data = [
{
"id": "00000000-0000-0000-0000-000000000001",
"slug": "check_gathered_account",
"name": "检查发现的账号",
"comment": "基于自动发现的账号结果进行检查分析,检查 用户组、公钥、sudoers 等信息",
},
{
"id": "00000000-0000-0000-0000-000000000002",
"slug": "check_account_secret",
"name": "检查账号密码强弱",
"comment": "基于账号密码的安全性进行检查分析, 检查密码强度、泄露等信息",
},
]
model_cls = CheckAccountEngine
if model_cls.objects.all().count() == 2:
return
for item in data:
model_cls.objects.create(**item)
def get_queryset(self):
self.init_if_need()
return CheckAccountEngine.objects.all()

View File

@ -80,7 +80,7 @@ class GatheredAccountViewSet(OrgBulkModelViewSet):
asset_id = request.data.get("asset_id") asset_id = request.data.get("asset_id")
username = request.data.get("username") username = request.data.get("username")
asset = get_object_or_404(Asset, pk=asset_id) asset = get_object_or_404(Asset, pk=asset_id)
handler = RiskHandler(asset, username) handler = RiskHandler(asset, username, request=self.request)
handler.handle_delete_remote() handler.handle_delete_remote()
return Response(status=status.HTTP_200_OK) return Response(status=status.HTTP_200_OK)

View File

@ -90,14 +90,14 @@ def check_account_secrets(accounts, assets):
origin_risk = origin_risks_dict.get(key) origin_risk = origin_risks_dict.get(key)
if origin_risk: if origin_risk:
origin_risk.details.append({"datetime": now}) origin_risk.details.append({"datetime": now, 'type': 'refind'})
update_risk(origin_risk) update_risk(origin_risk)
else: else:
create_risk({ create_risk({
"asset": d["account"].asset, "asset": d["account"].asset,
"username": d["account"].username, "username": d["account"].username,
"risk": d["risk"], "risk": d["risk"],
"details": [{"datetime": now}], "details": [{"datetime": now, 'type': 'init'}],
}) })
return summary, result return summary, result

View File

@ -144,20 +144,21 @@ class AnalyseAccountRisk:
def _update_risk(self, account): def _update_risk(self, account):
return account return account
def analyse_risk(self, asset, ori_account, d): def analyse_risk(self, asset, ori_account, d, sys_found):
if not self.check_risk: if not self.check_risk:
return return
basic = {"asset": asset, "username": d["username"]} basic = {"asset": asset, "username": d["username"]}
if ori_account: if ori_account:
self._analyse_item_changed(ori_account, d) self._analyse_item_changed(ori_account, d)
else: elif not sys_found:
self._create_risk( self._create_risk(
dict( dict(
**basic, risk="new_found", details=[{"datetime": self.now.isoformat()}] **basic,
risk="new_found",
details=[{"datetime": self.now.isoformat()}],
) )
) )
self._analyse_datetime_changed(ori_account, d, asset, d["username"]) self._analyse_datetime_changed(ori_account, d, asset, d["username"])
@ -227,9 +228,8 @@ class GatherAccountsManager(AccountBasePlaybookManager):
for asset_id, username in accounts: for asset_id, username in accounts:
self.ori_asset_usernames[str(asset_id)].add(username) self.ori_asset_usernames[str(asset_id)].add(username)
ga_accounts = ( ga_accounts = GatheredAccount.objects.filter(asset__in=assets).prefetch_related(
GatheredAccount.objects.filter(asset__in=assets) "asset"
.prefetch_related("asset")
) )
for account in ga_accounts: for account in ga_accounts:
self.ori_gathered_usernames[str(account.asset_id)].add(account.username) self.ori_gathered_usernames[str(account.asset_id)].add(account.username)
@ -345,6 +345,7 @@ class GatherAccountsManager(AccountBasePlaybookManager):
risk_analyser = AnalyseAccountRisk(self.check_risk) risk_analyser = AnalyseAccountRisk(self.check_risk)
for asset, accounts_data in self.asset_account_info.items(): for asset, accounts_data in self.asset_account_info.items():
ori_users = self.ori_asset_usernames[str(asset.id)]
with tmp_to_org(asset.org_id): with tmp_to_org(asset.org_id):
gathered_accounts = [] gathered_accounts = []
for d in accounts_data: for d in accounts_data:
@ -357,7 +358,8 @@ class GatherAccountsManager(AccountBasePlaybookManager):
self.create_gathered_account(d) self.create_gathered_account(d)
else: else:
self.update_gathered_account(ori_account, d) self.update_gathered_account(ori_account, d)
risk_analyser.analyse_risk(asset, ori_account, d) ori_found = username in ori_users
risk_analyser.analyse_risk(asset, ori_account, d, ori_found)
self.create_gathered_account.finish() self.create_gathered_account.finish()
self.update_gathered_account.finish() self.update_gathered_account.finish()

View File

@ -5,7 +5,7 @@ from copy import deepcopy
from django.db.models import QuerySet from django.db.models import QuerySet
from accounts.const import AutomationTypes from accounts.const import AutomationTypes
from accounts.models import Account from accounts.models import Account, GatheredAccount, AccountRisk
from common.utils import get_logger from common.utils import get_logger
from ..base.manager import AccountBasePlaybookManager from ..base.manager import AccountBasePlaybookManager
@ -13,59 +13,72 @@ logger = get_logger(__name__)
class RemoveAccountManager(AccountBasePlaybookManager): class RemoveAccountManager(AccountBasePlaybookManager):
super_accounts = ['root', 'administrator'] super_accounts = ["root", "administrator"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.host_account_mapper = dict() self.host_account_mapper = dict()
self.host_accounts = defaultdict(list) self.host_accounts = defaultdict(list)
snapshot_account = self.execution.snapshot.get('accounts', []) snapshot_account = self.execution.snapshot.get("accounts", [])
self.snapshot_asset_account_map = defaultdict(list) self.snapshot_asset_account_map = defaultdict(list)
for account in snapshot_account: for account in snapshot_account:
self.snapshot_asset_account_map[str(account['asset'])].append(account) self.snapshot_asset_account_map[str(account["asset"])].append(account)
def prepare_runtime_dir(self): def prepare_runtime_dir(self):
path = super().prepare_runtime_dir() path = super().prepare_runtime_dir()
ansible_config_path = os.path.join(path, 'ansible.cfg') ansible_config_path = os.path.join(path, "ansible.cfg")
with open(ansible_config_path, 'w') as f: with open(ansible_config_path, "w") as f:
f.write('[ssh_connection]\n') f.write("[ssh_connection]\n")
f.write('ssh_args = -o ControlMaster=no -o ControlPersist=no\n') f.write("ssh_args = -o ControlMaster=no -o ControlPersist=no\n")
return path return path
@classmethod @classmethod
def method_type(cls): def method_type(cls):
return AutomationTypes.remove_account return AutomationTypes.remove_account
def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs): def host_callback(
if host.get('error'): self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs
):
if host.get("error"):
return host return host
inventory_hosts = [] inventory_hosts = []
accounts_to_remove = self.snapshot_asset_account_map.get(str(asset.id), []) accounts_to_remove = self.snapshot_asset_account_map.get(str(asset.id), [])
for account in accounts_to_remove: for account in accounts_to_remove:
username = account.get('username') username = account.get("username")
if not username or username.lower() in self.super_accounts: if not username or username.lower() in self.super_accounts:
print("Super account can not be remove: ", username) print("Super account can not be remove: ", username)
continue continue
h = deepcopy(host) h = deepcopy(host)
h['name'] += '(' + username + ')' h["name"] += "(" + username + ")"
self.host_account_mapper[h['name']] = account self.host_account_mapper[h["name"]] = account
h['account'] = {'username': username} h["account"] = {"username": username}
inventory_hosts.append(h) inventory_hosts.append(h)
return inventory_hosts return inventory_hosts
def on_host_success(self, host, result): def on_host_success(self, host, result):
tuple_asset_gather_account = self.host_account_mapper.get(host) super().on_host_success(host, result)
if not tuple_asset_gather_account: account = self.host_account_mapper.get(host)
if not account:
return return
asset, gather_account = tuple_asset_gather_account
try: try:
Account.objects.filter( Account.objects.filter(
asset_id=asset.id, asset_id=account["asset"], username=account["username"]
username=gather_account.username
).delete() ).delete()
gather_account.delete() GatheredAccount.objects.filter(
asset_id=account["asset"], username=account["username"]
).delete()
risk = AccountRisk.objects.filter(
asset_id=account["asset"],
username=account["username"],
risk__in=["new_found"],
)
print("Account removed: ", account)
except Exception as e: except Exception as e:
print(f'\033[31m Delete account {gather_account.username} failed: {e} \033[0m\n') logger.error(
f"Failed to delete account {account['username']} on asset {account['asset']}: {e}"
)

View File

@ -3,6 +3,7 @@ from itertools import islice
from django.db import models from django.db import models
from django.db.models import TextChoices from django.db.models import TextChoices
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.utils import timezone
from common.const import ConfirmOrIgnore from common.const import ConfirmOrIgnore
from common.db.models import JMSBaseModel from common.db.models import JMSBaseModel
@ -68,6 +69,15 @@ class AccountRisk(JMSOrgBaseModel):
def __str__(self): def __str__(self):
return f"{self.username}@{self.asset} - {self.risk}" return f"{self.username}@{self.asset} - {self.risk}"
def set_status(self, status, user):
self.status = status
self.details.append({'date': timezone.now().isoformat(), 'message': f'{user.username} set status to {status}'})
self.save()
def update_details(self, message, user):
self.details.append({'date': timezone.now().isoformat(), 'message': f'{user.username} {message}'})
self.save(update_fields=['details'])
@classmethod @classmethod
def gen_fake_data(cls, count=1000, batch_size=50): def gen_fake_data(cls, count=1000, batch_size=50):
from assets.models import Asset from assets.models import Asset

View File

@ -1,6 +1,10 @@
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from common.const import ConfirmOrIgnore
from accounts.models import GatheredAccount, AccountRisk, SecretType, AutomationExecution from accounts.models import GatheredAccount, AccountRisk, SecretType, AutomationExecution
from django.utils import timezone
from common.const import ConfirmOrIgnore
TYPE_CHOICES = [ TYPE_CHOICES = [
("ignore", _("Ignore")), ("ignore", _("Ignore")),
@ -14,21 +18,54 @@ TYPE_CHOICES = [
class RiskHandler: class RiskHandler:
def __init__(self, asset, username): def __init__(self, asset, username, request=None, risk=''):
self.asset = asset self.asset = asset
self.username = username self.username = username
self.request = request
self.risk = risk
def handle(self, tp, risk=""): def handle(self, tp, risk=''):
self.risk = risk
attr = f"handle_{tp}" attr = f"handle_{tp}"
if hasattr(self, attr): if hasattr(self, attr):
return getattr(self, attr)(risk=risk) ret = getattr(self, attr)()
self.update_risk_if_need(tp)
return ret
else: else:
raise ValueError(f"Invalid risk type: {tp}") raise ValueError(f"Invalid risk type: {tp}")
def handle_ignore(self, risk=""): def update_risk_if_need(self, tp):
r = self.get_risk()
if not r:
return
status = ConfirmOrIgnore.ignored if tp == 'ignore' else ConfirmOrIgnore.confirmed
r.details.append({
**self.process_detail,
'action': tp, 'status': status
})
r.status = status
r.save()
def get_risk(self):
r = AccountRisk.objects.filter(asset=self.asset, username=self.username)
if self.risk:
r = r.filter(risk=self.risk)
return r.first()
def handle_ignore(self):
pass pass
def handle_add_account(self, risk=""): def handle_review(self):
pass
@property
def process_detail(self):
return {
"datetime": timezone.now().isoformat(), "type": "process",
"processor": str(self.request.user)
}
def handle_add_account(self):
data = { data = {
"username": self.username, "username": self.username,
"name": self.username, "name": self.username,
@ -37,18 +74,14 @@ class RiskHandler:
} }
self.asset.accounts.get_or_create(defaults=data, username=self.username) self.asset.accounts.get_or_create(defaults=data, username=self.username)
GatheredAccount.objects.filter(asset=self.asset, username=self.username).update( GatheredAccount.objects.filter(asset=self.asset, username=self.username).update(
present=True, status="confirmed" present=True, status=ConfirmOrIgnore.confirmed
)
(
AccountRisk.objects.filter(asset=self.asset, username=self.username)
.filter(risk__in=["new_found"])
.update(status="confirmed")
) )
self.risk = 'new_found'
def handle_disable_remote(self, risk=""): def handle_disable_remote(self):
pass pass
def handle_delete_remote(self, risk=""): def handle_delete_remote(self):
asset = self.asset asset = self.asset
execution = AutomationExecution() execution = AutomationExecution()
execution.snapshot = { execution.snapshot = {
@ -59,13 +92,13 @@ class RiskHandler:
} }
execution.save() execution.save()
execution.start() execution.start()
return execution return execution.summary
def handle_delete_both(self, risk=""): def handle_delete_both(self):
pass pass
def handle_change_password_add(self, risk=""): def handle_change_password_add(self):
pass pass
def handle_change_password(self, risk=""): def handle_change_password(self):
pass pass

View File

@ -11,6 +11,7 @@ from accounts.models import (
CheckAccountEngine, CheckAccountEngine,
) )
from assets.models import Asset from assets.models import Asset
from common.const import ConfirmOrIgnore
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from common.utils import get_logger from common.utils import get_logger
from .base import BaseAutomationSerializer from .base import BaseAutomationSerializer
@ -34,6 +35,9 @@ class AccountRiskSerializer(serializers.ModelSerializer):
risk = LabeledChoiceField( risk = LabeledChoiceField(
choices=RiskChoice.choices, required=False, read_only=True, label=_("Risk") choices=RiskChoice.choices, required=False, read_only=True, label=_("Risk")
) )
status = LabeledChoiceField(
choices=ConfirmOrIgnore.choices, required=False, label=_("Status")
)
class Meta: class Meta:
model = AccountRisk model = AccountRisk

View File

@ -15,6 +15,7 @@ from premailer import transform
from sshtunnel import SSHTunnelForwarder from sshtunnel import SSHTunnelForwarder
from assets.automations.methods import platform_automation_methods from assets.automations.methods import platform_automation_methods
from common.const import Status
from common.db.utils import safe_db_connection from common.db.utils import safe_db_connection
from common.tasks import send_mail_async from common.tasks import send_mail_async
from common.utils import get_logger, lazyproperty, is_openssh_format_key, ssh_pubkey_gen from common.utils import get_logger, lazyproperty, is_openssh_format_key, ssh_pubkey_gen
@ -97,13 +98,15 @@ class BaseManager:
self.summary = defaultdict(int) self.summary = defaultdict(int)
self.result = defaultdict(list) self.result = defaultdict(list)
self.duration = 0 self.duration = 0
self.status = 'success'
def get_assets_group_by_platform(self): def get_assets_group_by_platform(self):
return self.execution.all_assets_group_by_platform() return self.execution.all_assets_group_by_platform()
def pre_run(self): def pre_run(self):
self.execution.date_start = timezone.now() self.execution.date_start = timezone.now()
self.execution.save(update_fields=["date_start"]) self.execution.status = Status.running
self.execution.save(update_fields=["date_start", "status"])
def update_execution(self): def update_execution(self):
self.duration = int(time.time() - self.time_start) self.duration = int(time.time() - self.time_start)
@ -111,7 +114,7 @@ class BaseManager:
self.execution.duration = self.duration self.execution.duration = self.duration
self.execution.summary = self.summary self.execution.summary = self.summary
self.execution.result = self.result self.execution.result = self.result
self.execution.status = "success" self.execution.status = self.status
with safe_db_connection(): with safe_db_connection():
self.execution.save() self.execution.save()
@ -161,7 +164,11 @@ class BaseManager:
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
self.pre_run() self.pre_run()
try:
self.do_run(*args, **kwargs) self.do_run(*args, **kwargs)
except:
self.status = 'error'
finally:
self.post_run() self.post_run()
def do_run(self, *args, **kwargs): def do_run(self, *args, **kwargs):
@ -365,6 +372,7 @@ class BasePlaybookManager(PlaybookPrepareMixin, BaseManager):
def __init__(self, execution): def __init__(self, execution):
super().__init__(execution) super().__init__(execution)
self.params = execution.snapshot.get("params", {}) self.params = execution.snapshot.get("params", {})
self.host_success_callbacks = []
def get_assets_group_by_platform(self): def get_assets_group_by_platform(self):
return self.execution.all_assets_group_by_platform() return self.execution.all_assets_group_by_platform()
@ -451,6 +459,9 @@ class BasePlaybookManager(PlaybookPrepareMixin, BaseManager):
self.summary["ok_assets"] += 1 self.summary["ok_assets"] += 1
self.result["ok_assets"].append(host) self.result["ok_assets"].append(host)
for cb in self.host_success_callbacks:
cb(host, result)
def on_host_error(self, host, error, result): def on_host_error(self, host, error, result):
self.summary["fail_assets"] += 1 self.summary["fail_assets"] += 1
self.result["fail_assets"].append((host, str(error))) self.result["fail_assets"].append((host, str(error)))
@ -464,6 +475,11 @@ class BasePlaybookManager(PlaybookPrepareMixin, BaseManager):
detail = result.get("failures", "") or result.get("dark", "") detail = result.get("failures", "") or result.get("dark", "")
self.on_host_error(host, error, detail) self.on_host_error(host, error, detail)
def post_run(self):
if self.summary['fail_assets']:
self.status = 'failed'
super().post_run()
def on_runner_success(self, runner, cb): def on_runner_success(self, runner, cb):
summary = cb.summary summary = cb.summary
for state, hosts in summary.items(): for state, hosts in summary.items():

View File

@ -0,0 +1,31 @@
# Generated by Django 4.1.13 on 2024-12-02 11:30
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("assets", "0010_alter_automationexecution_duration"),
]
operations = [
migrations.AlterField(
model_name="automationexecution",
name="status",
field=models.CharField(
choices=[
("ready", "Ready"),
("pending", "Pending"),
("running", "Running"),
("success", "Success"),
("failed", "Failed"),
("error", "Error"),
("canceled", "Canceled"),
],
default="pending",
max_length=16,
verbose_name="Status",
),
),
]

View File

@ -7,7 +7,7 @@ from django.utils.translation import gettext_lazy as _
from assets.models.asset import Asset from assets.models.asset import Asset
from assets.models.node import Node from assets.models.node import Node
from assets.tasks import execute_asset_automation_task from assets.tasks import execute_asset_automation_task
from common.const.choices import Trigger from common.const.choices import Trigger, Status
from common.db.fields import EncryptJsonDictTextField from common.db.fields import EncryptJsonDictTextField
from ops.mixin import PeriodTaskModelMixin from ops.mixin import PeriodTaskModelMixin
from orgs.mixins.models import OrgModelMixin, JMSOrgBaseModel from orgs.mixins.models import OrgModelMixin, JMSOrgBaseModel
@ -16,9 +16,11 @@ from users.models import User
class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel): class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
accounts = models.JSONField(default=list, verbose_name=_("Accounts")) accounts = models.JSONField(default=list, verbose_name=_("Accounts"))
nodes = models.ManyToManyField('assets.Node', blank=True, verbose_name=_("Node")) nodes = models.ManyToManyField("assets.Node", blank=True, verbose_name=_("Node"))
assets = models.ManyToManyField('assets.Asset', blank=True, verbose_name=_("Assets")) assets = models.ManyToManyField(
type = models.CharField(max_length=16, verbose_name=_('Type')) "assets.Asset", blank=True, verbose_name=_("Assets")
)
type = models.CharField(max_length=16, verbose_name=_("Type"))
is_active = models.BooleanField(default=True, verbose_name=_("Is active")) is_active = models.BooleanField(default=True, verbose_name=_("Is active"))
params = models.JSONField(default=dict, verbose_name=_("Parameters")) params = models.JSONField(default=dict, verbose_name=_("Parameters"))
@ -26,10 +28,10 @@ class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
raise NotImplementedError raise NotImplementedError
def __str__(self): def __str__(self):
return self.name + '@' + str(self.created_by) return self.name + "@" + str(self.created_by)
class Meta: class Meta:
unique_together = [('org_id', 'name', 'type')] unique_together = [("org_id", "name", "type")]
verbose_name = _("Automation task") verbose_name = _("Automation task")
@classmethod @classmethod
@ -43,13 +45,13 @@ class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
def get_all_assets(self): def get_all_assets(self):
nodes = self.nodes.all() nodes = self.nodes.all()
node_asset_ids = Node.get_nodes_all_assets(*nodes).values_list('id', flat=True) node_asset_ids = Node.get_nodes_all_assets(*nodes).values_list("id", flat=True)
direct_asset_ids = self.assets.all().values_list('id', flat=True) direct_asset_ids = self.assets.all().values_list("id", flat=True)
asset_ids = set(list(direct_asset_ids) + list(node_asset_ids)) asset_ids = set(list(direct_asset_ids) + list(node_asset_ids))
return Asset.objects.filter(id__in=asset_ids) return Asset.objects.filter(id__in=asset_ids)
def all_assets_group_by_platform(self): def all_assets_group_by_platform(self):
assets = self.get_all_assets().prefetch_related('platform') assets = self.get_all_assets().prefetch_related("platform")
return assets.group_by_platform() return assets.group_by_platform()
@property @property
@ -64,17 +66,17 @@ class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
return name, task, args, kwargs return name, task, args, kwargs
def get_many_to_many_ids(self, field: str): def get_many_to_many_ids(self, field: str):
return [str(i) for i in getattr(self, field).all().values_list('id', flat=True)] return [str(i) for i in getattr(self, field).all().values_list("id", flat=True)]
def to_attr_json(self): def to_attr_json(self):
return { return {
'name': self.name, "name": self.name,
'type': self.type, "type": self.type,
'comment': self.comment, "comment": self.comment,
'accounts': self.accounts, "accounts": self.accounts,
'org_id': str(self.org_id), "org_id": str(self.org_id),
'nodes': self.get_many_to_many_ids('nodes'), "nodes": self.get_many_to_many_ids("nodes"),
'assets': self.get_many_to_many_ids('assets'), "assets": self.get_many_to_many_ids("assets"),
} }
@property @property
@ -96,7 +98,9 @@ class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
eid = str(uuid.uuid4()) eid = str(uuid.uuid4())
execution = self.execution_model.objects.create( execution = self.execution_model.objects.create(
id=eid, trigger=trigger, automation=self, id=eid,
trigger=trigger,
automation=self,
snapshot=self.to_attr_json(), snapshot=self.to_attr_json(),
) )
return execution.start() return execution.start()
@ -111,37 +115,60 @@ class AssetBaseAutomation(BaseAutomation):
class AutomationExecution(OrgModelMixin): class AutomationExecution(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True) id = models.UUIDField(default=uuid.uuid4, primary_key=True)
automation = models.ForeignKey( automation = models.ForeignKey(
'BaseAutomation', related_name='executions', on_delete=models.CASCADE, "BaseAutomation",
verbose_name=_('Automation task'), null=True related_name="executions",
on_delete=models.CASCADE,
verbose_name=_("Automation task"),
null=True,
)
# pending, running, success, failed, terminated
status = models.CharField(
max_length=16, default=Status.pending, choices=Status.choices, verbose_name=_("Status")
)
date_created = models.DateTimeField(
auto_now_add=True, verbose_name=_("Date created")
)
date_start = models.DateTimeField(
null=True, verbose_name=_("Date start"), db_index=True
) )
status = models.CharField(max_length=16, default='pending', verbose_name=_('Status'))
date_created = models.DateTimeField(auto_now_add=True, verbose_name=_('Date created'))
date_start = models.DateTimeField(null=True, verbose_name=_('Date start'), db_index=True)
date_finished = models.DateTimeField(null=True, verbose_name=_("Date finished")) date_finished = models.DateTimeField(null=True, verbose_name=_("Date finished"))
duration = models.IntegerField(default=0, verbose_name=_('Duration')) duration = models.IntegerField(default=0, verbose_name=_("Duration"))
snapshot = EncryptJsonDictTextField( snapshot = EncryptJsonDictTextField(
default=dict, blank=True, null=True, verbose_name=_('Automation snapshot') default=dict, blank=True, null=True, verbose_name=_("Automation snapshot")
) )
trigger = models.CharField( trigger = models.CharField(
max_length=128, default=Trigger.manual, choices=Trigger.choices, max_length=128,
verbose_name=_('Trigger mode') default=Trigger.manual,
choices=Trigger.choices,
verbose_name=_("Trigger mode"),
) )
summary = models.JSONField(default=dict, verbose_name=_('Summary')) summary = models.JSONField(default=dict, verbose_name=_("Summary"))
result = models.JSONField(default=dict, verbose_name=_('Result')) result = models.JSONField(default=dict, verbose_name=_("Result"))
class Meta: class Meta:
ordering = ('org_id', '-date_start',) ordering = (
verbose_name = _('Automation task execution') "org_id",
"-date_start",
)
verbose_name = _("Automation task execution")
@property
def is_finished(self):
return bool(self.date_finished)
@property
def is_success(self):
return self.status == Status.success
@property @property
def manager_type(self): def manager_type(self):
return self.snapshot['type'] return self.snapshot["type"]
def get_all_asset_ids(self): def get_all_asset_ids(self):
node_ids = self.snapshot.get('nodes', []) node_ids = self.snapshot.get("nodes", [])
asset_ids = self.snapshot.get('assets', []) asset_ids = self.snapshot.get("assets", [])
nodes = Node.objects.filter(id__in=node_ids) nodes = Node.objects.filter(id__in=node_ids)
node_asset_ids = Node.get_nodes_all_assets(*nodes).values_list('id', flat=True) node_asset_ids = Node.get_nodes_all_assets(*nodes).values_list("id", flat=True)
asset_ids = set(list(asset_ids) + list(node_asset_ids)) asset_ids = set(list(asset_ids) + list(node_asset_ids))
return asset_ids return asset_ids
@ -150,12 +177,12 @@ class AutomationExecution(OrgModelMixin):
return Asset.objects.filter(id__in=asset_ids) return Asset.objects.filter(id__in=asset_ids)
def all_assets_group_by_platform(self): def all_assets_group_by_platform(self):
assets = self.get_all_assets().prefetch_related('platform') assets = self.get_all_assets().prefetch_related("platform")
return assets.group_by_platform() return assets.group_by_platform()
@property @property
def recipients(self): def recipients(self):
recipients = self.snapshot.get('recipients') recipients = self.snapshot.get("recipients")
if not recipients: if not recipients:
return [] return []
users = User.objects.filter(id__in=recipients) users = User.objects.filter(id__in=recipients)
@ -164,6 +191,7 @@ class AutomationExecution(OrgModelMixin):
@property @property
def manager(self): def manager(self):
from assets.automations.endpoint import ExecutionManager from assets.automations.endpoint import ExecutionManager
return ExecutionManager(execution=self) return ExecutionManager(execution=self)
def start(self): def start(self):

View File

@ -433,3 +433,8 @@ def convert_html_to_markdown(html_str):
markdown = markdown.replace('\n\n', '\n') markdown = markdown.replace('\n\n', '\n')
markdown = markdown.replace('\n ', '\n') markdown = markdown.replace('\n ', '\n')
return markdown return markdown
def many_get(d, keys, default=None):
res = [d.get(key, default) for key in keys]
return res

View File

@ -1243,7 +1243,7 @@
"TestLdapLoginTitle": "Test ldap user login", "TestLdapLoginTitle": "Test ldap user login",
"TestNodeAssetConnectivity": "Test assets connectivity of node", "TestNodeAssetConnectivity": "Test assets connectivity of node",
"TestPortErrorMsg": "Port error, please re-enter", "TestPortErrorMsg": "Port error, please re-enter",
"TestSelected": "Test selected", "TestSelected": "Verify selected",
"TestSuccessMsg": "Test succeeded", "TestSuccessMsg": "Test succeeded",
"Thursday": "Thu", "Thursday": "Thu",
"Ticket": "Ticket", "Ticket": "Ticket",