# -*- coding: utf-8 -*-
#
from django.db.models import Q, Count
from django.http import HttpResponse
from django.shortcuts import get_object_or_404
from django.utils import timezone
from rest_framework.decorators import action
from rest_framework.exceptions import MethodNotAllowed
from rest_framework.response import Response

from accounts import serializers
from accounts.const import AutomationTypes
from accounts.models import (
    CheckAccountAutomation,
    AccountRisk,
    RiskChoice,
    CheckAccountEngine,
    AutomationExecution,
)
from assets.models import Asset
from common.api import JMSModelViewSet
from common.utils import many_get
from orgs.mixins.api import OrgBulkModelViewSet
from .base import AutomationExecutionViewSet

__all__ = [
    "CheckAccountAutomationViewSet",
    "CheckAccountExecutionViewSet",
    "AccountRiskViewSet",
    "CheckAccountEngineViewSet",
]

from ...risk_handlers import RiskHandler


class CheckAccountAutomationViewSet(OrgBulkModelViewSet):
    """Viewset for ``CheckAccountAutomation``, filterable and searchable by name."""

    model = CheckAccountAutomation
    filterset_fields = ("name",)
    search_fields = filterset_fields
    serializer_class = serializers.CheckAccountAutomationSerializer


class CheckAccountExecutionViewSet(AutomationExecutionViewSet):
    """Viewset for executions whose automation type is ``check_account``."""

    # NOTE(review): the codenames below are spelled two ways
    # ("checkaccountexecution" vs "checkaccountsexecution"). Confirm which
    # spelling is actually registered as a permission before normalising them.
    rbac_perms = (
        ("list", "accounts.view_checkaccountexecution"),
        ("retrieve", "accounts.view_checkaccountsexecution"),
        ("create", "accounts.add_checkaccountexecution"),
        ("adhoc", "accounts.add_checkaccountexecution"),
        ("report", "accounts.view_checkaccountsexecution"),
    )
    ordering = ("-date_created",)
    tp = AutomationTypes.check_account

    def get_queryset(self):
        """Return the parent queryset restricted to this viewset's automation type."""
        queryset = super().get_queryset()
        queryset = queryset.filter(automation__type=self.tp)
        return queryset

    @action(methods=["get"], detail=False, url_path="adhoc")
    def adhoc(self, request, *args, **kwargs):
        """Run a one-off account check against a single asset and return its report.

        Reads ``asset_id`` from the query string. Responds with HTTP 400 when it
        is missing and 404 when no ``Asset`` has that primary key. Otherwise a
        new ``AutomationExecution`` is saved and started, and the report
        produced by its manager is returned as the response body.
        """
        asset_id = request.query_params.get("asset_id")
        if not asset_id:
            return Response(status=400, data={"asset_id": "This field is required."})

        # Only used to raise 404 for an unknown asset; the object itself is unused.
        get_object_or_404(Asset, pk=asset_id)

        execution = AutomationExecution()
        execution.snapshot = {
            "assets": [asset_id],
            "nodes": [],
            "type": AutomationTypes.check_account,
            "engines": ["check_account_secret"],
            "name": "Check asset risk: {} {}".format(asset_id, timezone.now()),
        }
        execution.save()
        execution.start()
        report = execution.manager.gen_report()
        return HttpResponse(report)


class AccountRiskViewSet(OrgBulkModelViewSet):
    """Viewset for ``AccountRisk`` records; PUT and POST on the resource are rejected."""

    model = AccountRisk
    search_fields = ("username", "asset")
    filterset_fields = ("risk", "status", "asset")
    serializer_classes = {
        "default": serializers.AccountRiskSerializer,
        "assets": serializers.AssetRiskSerializer,
        "handle": serializers.HandleRiskSerializer,
    }
    ordering_fields = ("asset", "risk", "status", "username", "date_created")
    ordering = ("status", "asset", "date_created")
    # NOTE(review): "sync_accounts" uses the "assets" app label while the other
    # entries use "accounts", and no such action is defined in this class.
    # Confirm whether this entry is still needed / correctly labelled.
    rbac_perms = {
        "sync_accounts": "assets.add_accountrisk",
        "assets": "accounts.view_accountrisk",
        "handle": "accounts.change_accountrisk",
    }

    def update(self, request, *args, **kwargs):
        """Reject full updates: always raises ``MethodNotAllowed("PUT")``."""
        raise MethodNotAllowed("PUT")

    def create(self, request, *args, **kwargs):
        """Reject creation: always raises ``MethodNotAllowed("POST")``."""
        raise MethodNotAllowed("POST")

    @action(methods=["get"], detail=False, url_path="assets")
    def assets(self, request, *args, **kwargs):
        """Return per-asset risk counts, paginated.

        Each row carries the asset id, name, address and platform name, a
        ``risk_total`` count, and one ``<risk>_count`` column for every value
        in ``RiskChoice.choices``.
        """
        # One conditional count per risk value, e.g. "<value>_count".
        annotations = {
            f"{value}_count": Count("id", filter=Q(risk=value))
            for value, _label in RiskChoice.choices
        }
        queryset = (
            AccountRisk.objects.select_related(
                "asset", "asset__platform"
            )  # select_related to optimise the asset / asset__platform lookups
            .values(
                "asset__id", "asset__name", "asset__address", "asset__platform__name"
            )  # the fields exposed per asset (also the grouping key)
            .annotate(risk_total=Count("id"))  # total number of risks
            .annotate(**annotations)  # per-risk counts defined above
        )
        return self.get_paginated_response_from_queryset(queryset)

    @action(methods=["post"], detail=False, url_path="handle")
    def handle(self, request, *args, **kwargs):
        """Apply a handling action to a risk via ``RiskHandler``.

        The request body is validated with this action's serializer; ``asset``,
        ``username``, ``action`` and ``risk`` are read from the validated data.
        If the handler returns a truthy result it is serialized with
        ``AccountRiskSerializer``; otherwise a plain success message is
        returned.

        Raises:
            ValidationError: when the request body fails validation
                (``is_valid(raise_exception=True)``).
        """
        s = self.get_serializer(data=request.data)
        s.is_valid(raise_exception=True)

        asset, username, act, risk = many_get(
            s.validated_data, ("asset", "username", "action", "risk")
        )
        handler = RiskHandler(asset=asset, username=username, request=self.request)
        data = handler.handle(act, risk)
        if not data:
            # Nothing to serialize: answer directly instead of pushing a bare
            # message dict through the model serializer, which expects an
            # AccountRisk-shaped instance.
            return Response(data={"message": "Success"})

        s = serializers.AccountRiskSerializer(instance=data)
        return Response(data=s.data)


class CheckAccountEngineViewSet(JMSModelViewSet):
    """Viewset over the default check-account engines, served from a list."""

    search_fields = ("name",)
    serializer_class = serializers.CheckAccountEngineSerializer
    perm_model = CheckAccountEngine

    def get_queryset(self):
        """Return whatever ``CheckAccountEngine.get_default_engines()`` provides."""
        return CheckAccountEngine.get_default_engines()

    def filter_queryset(self, queryset: list):
        """Filter the engine list by the ``search`` query parameter.

        When ``search`` is present, keep only items whose ``name`` contains it
        (case-sensitive substring match); otherwise return the list unchanged.
        """
        search = self.request.GET.get('search')
        if search is not None:
            queryset = [
                item for item in queryset
                if search in item['name']
            ]
        return queryset