diff --git a/apps/terminal/api/command.py b/apps/terminal/api/command.py index f969e24d7..d7868e4a0 100644 --- a/apps/terminal/api/command.py +++ b/apps/terminal/api/command.py @@ -11,6 +11,7 @@ from rest_framework.response import Response from rest_framework.decorators import action from django.template import loader +from common.http import is_true from terminal.models import CommandStorage, Command from terminal.filters import CommandFilter from orgs.utils import current_org @@ -140,7 +141,21 @@ class CommandViewSet(viewsets.ModelViewSet): if session_id and not command_storage_id: # 会话里的命令列表肯定会提供 session_id,这里防止 merge 的时候取全量的数据 return self.merge_all_storage_list(request, *args, **kwargs) - return super().list(request, *args, **kwargs) + + queryset = self.filter_queryset(self.get_queryset()) + + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + query_all = self.request.query_params.get('all', False) + if is_true(query_all): + # 适配像 ES 这种没有指定分页只返回少量数据的情况 + queryset = queryset[:] + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) def get_queryset(self): command_storage_id = self.request.query_params.get('command_storage_id') diff --git a/apps/terminal/backends/command/es.py b/apps/terminal/backends/command/es.py index fc0f247f4..d8197391d 100644 --- a/apps/terminal/backends/command/es.py +++ b/apps/terminal/backends/command/es.py @@ -10,6 +10,7 @@ import inspect from django.db.models import QuerySet as DJQuerySet from elasticsearch import Elasticsearch from elasticsearch.helpers import bulk +from elasticsearch.exceptions import RequestError from common.utils.common import lazyproperty from common.utils import get_logger @@ -31,6 +32,15 @@ class CommandStore(): kwargs['verify_certs'] = None self.es = Elasticsearch(hosts=hosts, max_retries=0, **kwargs) + def pre_use_check(self): + self._ensure_index_exists() + + def _ensure_index_exists(self): + try: + self.es.indices.create(self.index) + except RequestError: + pass + @staticmethod def make_data(command): data = dict( @@ -234,6 +244,7 @@ class QuerySet(DJQuerySet): uqs = QuerySet(self._command_store_config) uqs._method_calls = self._method_calls.copy() uqs._slice = self._slice + uqs.model = self.model return uqs def count(self, limit_to_max_result_window=True): diff --git a/apps/terminal/models/storage.py b/apps/terminal/models/storage.py index 4826e2eef..883e5f67a 100644 --- a/apps/terminal/models/storage.py +++ b/apps/terminal/models/storage.py @@ -76,6 +76,15 @@ class CommandStorage(CommonModelMixin): qs.model = Command return qs + def save(self, force_insert=False, force_update=False, using=None, + update_fields=None): + super().save() + + if self.type in TYPE_ENGINE_MAPPING: + engine_mod = import_module(TYPE_ENGINE_MAPPING[self.type]) + backend = engine_mod.CommandStore(self.config) + backend.pre_use_check() + class ReplayStorage(CommonModelMixin): name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)