diff --git a/apps/accounts/api/account/account.py b/apps/accounts/api/account/account.py index 24eb47b27..4416f8f9c 100644 --- a/apps/accounts/api/account/account.py +++ b/apps/accounts/api/account/account.py @@ -8,7 +8,7 @@ from rest_framework.status import HTTP_200_OK from accounts import serializers from accounts.const import ChangeSecretRecordStatusChoice -from accounts.filters import AccountFilterSet +from accounts.filters import AccountFilterSet, NodeFilterBackend from accounts.mixins import AccountRecordViewLogMixin from accounts.models import Account, ChangeSecretRecord from assets.models import Asset, Node @@ -31,7 +31,7 @@ __all__ = [ class AccountViewSet(OrgBulkModelViewSet): model = Account search_fields = ('username', 'name', 'asset__name', 'asset__address', 'comment') - extra_filter_backends = [AttrRulesFilterBackend] + extra_filter_backends = [AttrRulesFilterBackend, NodeFilterBackend] filterset_class = AccountFilterSet serializer_classes = { 'default': serializers.AccountSerializer, diff --git a/apps/accounts/api/automations/check_account.py b/apps/accounts/api/automations/check_account.py index 080422bd6..64ff46054 100644 --- a/apps/accounts/api/automations/check_account.py +++ b/apps/accounts/api/automations/check_account.py @@ -30,6 +30,8 @@ __all__ = [ "CheckAccountEngineViewSet", ] +from ...filters import NodeFilterBackend + from ...risk_handlers import RiskHandler @@ -80,7 +82,8 @@ class CheckAccountExecutionViewSet(AutomationExecutionViewSet): class AccountRiskViewSet(OrgBulkModelViewSet): model = AccountRisk search_fields = ("username", "asset") - filterset_fields = ("risk", "status", "asset") + filterset_fields = ("risk", "status", "asset_id") + extra_filter_backends = [NodeFilterBackend] serializer_classes = { "default": serializers.AccountRiskSerializer, "assets": serializers.AssetRiskSerializer, diff --git a/apps/accounts/filters.py b/apps/accounts/filters.py index 716f7394a..90df5d375 100644 --- a/apps/accounts/filters.py +++ b/apps/accounts/filters.py @@ -3,14 +3,41 @@ from django.db.models import Q from django.utils import timezone from django_filters import rest_framework as drf_filters +from rest_framework import filters +from rest_framework.compat import coreapi from assets.models import Node +from assets.utils import get_node_from_request from common.drf.filters import BaseFilterSet from common.utils.timezone import local_zero_hour, local_now from .const.automation import ChangeSecretRecordStatusChoice from .models import Account, GatheredAccount, ChangeSecretRecord, PushSecretRecord, IntegrationApplication +class NodeFilterBackend(filters.BaseFilterBackend): + fields = ['node_id'] + + def get_schema_fields(self, view): + return [ + coreapi.Field( + name=field, location='query', required=False, + type='string', example='', description='', schema=None, + ) + for field in self.fields + ] + + def filter_queryset(self, request, queryset, view): + node = get_node_from_request(request) + if node is None: + return queryset + + node_qs = Node.objects.none() + node_qs |= node.get_all_children(with_self=True) + node_ids = list(node_qs.values_list("id", flat=True)) + queryset = queryset.filter(asset__nodes__in=node_ids) + return queryset + + class AccountFilterSet(BaseFilterSet): ip = drf_filters.CharFilter(field_name="address", lookup_expr="exact") hostname = drf_filters.CharFilter(field_name="name", lookup_expr="exact") @@ -19,8 +46,6 @@ class AccountFilterSet(BaseFilterSet): asset_id = drf_filters.CharFilter(field_name="asset", lookup_expr="exact") asset = drf_filters.CharFilter(field_name="asset", lookup_expr="exact") assets = drf_filters.CharFilter(field_name="asset_id", lookup_expr="exact") - nodes = drf_filters.CharFilter(method="filter_nodes") - node_id = drf_filters.CharFilter(method="filter_nodes") has_secret = drf_filters.BooleanFilter(method="filter_has_secret") platform = drf_filters.CharFilter( field_name="asset__platform_id", lookup_expr="exact" @@ -36,9 +61,7 @@ class AccountFilterSet(BaseFilterSet): latest_updated = drf_filters.BooleanFilter(method="filter_latest") latest_secret_changed = drf_filters.BooleanFilter(method="filter_latest") latest_secret_change_failed = drf_filters.BooleanFilter(method="filter_latest") - risk = drf_filters.CharFilter( - method="filter_risk", - ) + risk = drf_filters.CharFilter(method="filter_risk") integrationapplication = drf_filters.CharFilter(method="filter_integrationapplication") long_time_no_change_secret = drf_filters.BooleanFilter(method="filter_long_time") long_time_no_verified = drf_filters.BooleanFilter(method="filter_long_time") @@ -111,19 +134,6 @@ class AccountFilterSet(BaseFilterSet): queryset = queryset.filter(**kwargs) return queryset - @staticmethod - def filter_nodes(queryset, name, value): - nodes = Node.objects.filter(id=value) - if not nodes: - return queryset - - node_qs = Node.objects.none() - for node in nodes: - node_qs |= node.get_all_children(with_self=True) - node_ids = list(node_qs.values_list("id", flat=True)) - queryset = queryset.filter(asset__nodes__in=node_ids) - return queryset - class Meta: model = Account fields = [