diff --git a/apps/common/drf/exc_handlers.py b/apps/common/drf/exc_handlers.py index 93ff84146..b99a53547 100644 --- a/apps/common/drf/exc_handlers.py +++ b/apps/common/drf/exc_handlers.py @@ -19,7 +19,10 @@ def extract_object_name(exc, index=0): `No User matches the given query.` 提取 `User`,`index=1` """ - (msg, *_) = exc.args + if exc.args: + (msg, *others) = exc.args + else: + return gettext('Object') return gettext(msg.split(sep=' ', maxsplit=index + 1)[index]) diff --git a/apps/common/drf/metadata.py b/apps/common/drf/metadata.py index 04c365d97..cc2903d2f 100644 --- a/apps/common/drf/metadata.py +++ b/apps/common/drf/metadata.py @@ -97,6 +97,8 @@ class SimpleMetadataWithFilters(SimpleMetadata): fields = view.filterset_fields elif hasattr(view, 'get_filterset_fields'): fields = view.get_filterset_fields(request) + elif hasattr(view, 'filterset_class'): + fields = view.filterset_class.Meta.fields if isinstance(fields, dict): fields = list(fields.keys()) diff --git a/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py b/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py index 3a1d49016..d7a5c23dc 100644 --- a/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py +++ b/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py @@ -7,7 +7,7 @@ from common.utils import get_logger from perms.pagination import NodeGrantedAssetPagination, AllGrantedAssetPagination from assets.models import Asset, Node from perms import serializers -from perms.utils.asset.user_permission import UserGrantedAssetsQueryUtils, QuerySetStage +from perms.utils.asset.user_permission import UserGrantedAssetsQueryUtils logger = get_logger(__name__) diff --git a/apps/perms/utils/asset/user_permission.py b/apps/perms/utils/asset/user_permission.py index 14205e12e..0741f42e0 100644 --- a/apps/perms/utils/asset/user_permission.py +++ b/apps/perms/utils/asset/user_permission.py @@ -53,122 +53,6 @@ def get_user_all_asset_perm_ids(user) -> set: return asset_perm_ids -class QuerySetStage: - def __init__(self): - self._prefetch_related = set() - self._only = () - self._filters = [] - self._querysets_and = [] - self._querysets_or = [] - self._order_by = None - self._annotate = [] - self._before_union_merge_funs = set() - self._after_union_merge_funs = set() - - def annotate(self, *args, **kwargs): - self._annotate.append((args, kwargs)) - self._before_union_merge_funs.add(self._merge_annotate) - return self - - def prefetch_related(self, *lookups): - self._prefetch_related.update(lookups) - self._before_union_merge_funs.add(self._merge_prefetch_related) - return self - - def only(self, *fields): - self._only = fields - self._before_union_merge_funs.add(self._merge_only) - return self - - def order_by(self, *field_names): - self._order_by = field_names - self._after_union_merge_funs.add(self._merge_order_by) - return self - - def filter(self, *args, **kwargs): - self._filters.append((args, kwargs)) - self._before_union_merge_funs.add(self._merge_filters) - return self - - def and_with_queryset(self, qs: QuerySet): - assert isinstance(qs, QuerySet), f'Must be `QuerySet`' - self._order_by = qs.query.order_by - self._after_union_merge_funs.add(self._merge_order_by) - self._querysets_and.append(qs.order_by()) - self._before_union_merge_funs.add(self._merge_querysets_and) - return self - - def or_with_queryset(self, qs: QuerySet): - assert isinstance(qs, QuerySet), f'Must be `QuerySet`' - self._order_by = qs.query.order_by - self._after_union_merge_funs.add(self._merge_order_by) - self._querysets_or.append(qs.order_by()) - self._before_union_merge_funs.add(self._merge_querysets_or) - return self - - def merge_multi_before_union(self, *querysets): - ret = [] - for qs in querysets: - qs = self.merge_before_union(qs) - ret.append(qs) - return ret - - def _merge_only(self, qs: QuerySet): - if self._only: - qs = qs.only(*self._only) - return qs - - def _merge_filters(self, qs: QuerySet): - if self._filters: - for args, kwargs in self._filters: - qs = qs.filter(*args, **kwargs) - return qs - - def _merge_querysets_and(self, qs: QuerySet): - if self._querysets_and: - for qs_and in self._querysets_and: - qs &= qs_and - return qs - - def _merge_annotate(self, qs: QuerySet): - if self._annotate: - for args, kwargs in self._annotate: - qs = qs.annotate(*args, **kwargs) - return qs - - def _merge_querysets_or(self, qs: QuerySet): - if self._querysets_or: - for qs_or in self._querysets_or: - qs |= qs_or - return qs - - def _merge_prefetch_related(self, qs: QuerySet): - if self._prefetch_related: - qs = qs.prefetch_related(*self._prefetch_related) - return qs - - def _merge_order_by(self, qs: QuerySet): - if self._order_by is not None: - qs = qs.order_by(*self._order_by) - return qs - - def merge_before_union(self, qs: QuerySet) -> QuerySet: - assert isinstance(qs, QuerySet), f'Must be `QuerySet`' - for fun in self._before_union_merge_funs: - qs = fun(qs) - return qs - - def merge_after_union(self, qs: QuerySet) -> QuerySet: - for fun in self._after_union_merge_funs: - qs = fun(qs) - return qs - - def merge(self, qs: QuerySet) -> QuerySet: - qs = self.merge_before_union(qs) - qs = self.merge_after_union(qs) - return qs - - class UserGrantedTreeRefreshController: key_template = 'perms.user.node_tree.builded_orgs.user_id:{user_id}' diff --git a/apps/terminal/api/command.py b/apps/terminal/api/command.py index cd01df1b3..e75ce491e 100644 --- a/apps/terminal/api/command.py +++ b/apps/terminal/api/command.py @@ -8,11 +8,14 @@ from rest_framework import viewsets from rest_framework import generics from rest_framework.fields import DateTimeField from rest_framework.response import Response -from rest_framework import status +from rest_framework.decorators import action from django.template import loader +from terminal.models import CommandStorage +from terminal.filters import CommandFilter from orgs.utils import current_org from common.permissions import IsOrgAdminOrAppUser, IsOrgAuditor, IsAppUser +from common.const.http import GET from common.utils import get_logger from terminal.utils import send_command_alert_mail from terminal.serializers import InsecureCommandAlertSerializer @@ -89,7 +92,7 @@ class CommandQueryMixin: return date_from_st, date_to_st -class CommandViewSet(CommandQueryMixin, viewsets.ModelViewSet): +class CommandViewSet(viewsets.ModelViewSet): """接受app发送来的command log, 格式如下 { "user": "admin", @@ -103,7 +106,16 @@ class CommandViewSet(CommandQueryMixin, viewsets.ModelViewSet): """ command_store = get_command_storage() + permission_classes = [IsOrgAdminOrAppUser | IsOrgAuditor] serializer_class = SessionCommandSerializer + filterset_class = CommandFilter + ordering_fields = ('timestamp', ) + + def get_queryset(self): + command_storage_id = self.request.query_params.get('command_storage_id') + storage = CommandStorage.objects.get(id=command_storage_id) + qs = storage.get_command_queryset() + return qs def create(self, request, *args, **kwargs): serializer = self.serializer_class(data=request.data, many=True) diff --git a/apps/terminal/api/storage.py b/apps/terminal/api/storage.py index be3002060..b2bd69d28 100644 --- a/apps/terminal/api/storage.py +++ b/apps/terminal/api/storage.py @@ -3,9 +3,15 @@ from rest_framework import viewsets, generics, status from rest_framework.response import Response +from rest_framework.request import Request +from rest_framework.decorators import action from django.utils.translation import ugettext_lazy as _ +from django_filters import utils +from terminal import const +from common.const.http import GET from common.permissions import IsSuperUser +from terminal.filters import CommandStorageFilter, CommandFilter, CommandFilterForStorageTree from ..models import CommandStorage, ReplayStorage from ..serializers import CommandStorageSerializer, ReplayStorageSerializer @@ -30,11 +36,52 @@ class BaseStorageViewSetMixin: class CommandStorageViewSet(BaseStorageViewSetMixin, viewsets.ModelViewSet): - filterset_fields = ('name', 'type',) - search_fields = filterset_fields + search_fields = ('name', 'type',) queryset = CommandStorage.objects.all() serializer_class = CommandStorageSerializer permission_classes = (IsSuperUser,) + filterset_class = CommandStorageFilter + + @action(methods=[GET], detail=False, filterset_class=CommandFilterForStorageTree) + def tree(self, request: Request): + storage_qs = self.get_queryset().exclude(name='null') + storages_with_count = [] + for storage in storage_qs: + command_qs = storage.get_command_queryset() + filterset = CommandFilter( + data=request.query_params, queryset=command_qs, + request=request + ) + if not filterset.is_valid(): + raise utils.translate_validation(filterset.errors) + command_qs = filterset.qs + if storage.type == const.CommandStorageTypeChoices.es: + command_count = command_qs.count(limit_to_max_result_window=False) + else: + command_count = command_qs.count() + storages_with_count.append((storage, command_count)) + + root = { + 'id': 'root', + 'name': _('Command storages'), + 'title': _('Command storages'), + 'pId': '', + 'isParent': True, + 'open': True, + } + + nodes = [ + { + 'id': storage.id, + 'name': f'{storage.name}({storage.type})({command_count})', + 'title': f'{storage.name}({storage.type})', + 'pId': 'root', + 'isParent': False, + 'open': False, + } for storage, command_count in storages_with_count + ] + nodes.append(root) + return Response(data=nodes) class ReplayStorageViewSet(BaseStorageViewSetMixin, viewsets.ModelViewSet): diff --git a/apps/terminal/backends/command/es.py b/apps/terminal/backends/command/es.py index 43bd52c02..1137f6ec8 100644 --- a/apps/terminal/backends/command/es.py +++ b/apps/terminal/backends/command/es.py @@ -1,53 +1,270 @@ # -*- coding: utf-8 -*- # - from datetime import datetime -from jms_storage.es import ESStorage +from functools import reduce, partial +from itertools import groupby +import pytz +from uuid import UUID +import inspect + +from django.db.models import QuerySet as DJQuerySet +from elasticsearch import Elasticsearch +from elasticsearch.helpers import bulk + +from common.utils.common import lazyproperty from common.utils import get_logger -from .base import CommandBase from .models import AbstractSessionCommand logger = get_logger(__file__) -class CommandStore(ESStorage, CommandBase): - def __init__(self, params): - super().__init__(params) +class CommandStore(): + def __init__(self, config): + hosts = config.get("HOSTS") + kwargs = config.get("OTHER", {}) + self.index = config.get("INDEX") or 'jumpserver' + self.doc_type = config.get("DOC_TYPE") or 'command_store' + self.es = Elasticsearch(hosts=hosts, **kwargs) - def filter(self, date_from=None, date_to=None, - user=None, asset=None, system_user=None, - input=None, session=None, risk_level=None, org_id=None): + @staticmethod + def make_data(command): + data = dict( + user=command["user"], asset=command["asset"], + system_user=command["system_user"], input=command["input"], + output=command["output"], risk_level=command["risk_level"], + session=command["session"], timestamp=command["timestamp"], + org_id=command["org_id"] + ) + data["date"] = datetime.fromtimestamp(command['timestamp'], tz=pytz.UTC) + return data - if date_from is not None: - if isinstance(date_from, float): - date_from = datetime.fromtimestamp(date_from) - if date_to is not None: - if isinstance(date_to, float): - date_to = datetime.fromtimestamp(date_to) - - try: - data = super().filter(date_from=date_from, date_to=date_to, - user=user, asset=asset, system_user=system_user, - input=input, session=session, - risk_level=risk_level, org_id=org_id) - except Exception as e: - logger.error(e, exc_info=True) - return [] - else: - return AbstractSessionCommand.from_multi_dict( - [item["_source"] for item in data["hits"] if item] + def bulk_save(self, command_set, raise_on_error=True): + actions = [] + for command in command_set: + data = dict( + _index=self.index, + _type=self.doc_type, + _source=self.make_data(command), ) + actions.append(data) + return bulk(self.es, actions, index=self.index, raise_on_error=raise_on_error) + + def save(self, command): + """ + 保存命令到数据库 + """ + data = self.make_data(command) + return self.es.index(index=self.index, doc_type=self.doc_type, body=data) + + def filter(self, query: dict, from_=None, size=None, sort=None): + body = self.get_query_body(**query) + + data = self.es.search( + index=self.index, doc_type=self.doc_type, body=body, from_=from_, size=size, + sort=sort + ) + + return AbstractSessionCommand.from_multi_dict( + [item['_source'] for item in data['hits']['hits'] if item] + ) + + def count(self, **query): + body = self.get_query_body(**query) + data = self.es.count(index=self.index, doc_type=self.doc_type, body=body) + return data["count"] + + def __getattr__(self, item): + return getattr(self.es, item) - def count(self, date_from=None, date_to=None, user=None, asset=None, - system_user=None, input=None, session=None): + def all(self): + """返回所有数据""" + raise NotImplementedError("Not support") + + def ping(self): try: - count = super().count( - date_from=date_from, date_to=date_to, user=user, asset=asset, - system_user=system_user, input=input, session=session - ) - except Exception as e: - logger.error(e, exc_info=True) - return 0 - else: - return count + return self.es.ping() + except Exception: + return False + + @staticmethod + def get_query_body(**kwargs): + new_kwargs = {} + for k, v in kwargs.items(): + new_kwargs[k] = str(v) if isinstance(v, UUID) else v + kwargs = new_kwargs + + exact_fields = {} + match_fields = {'session', 'input', 'org_id', 'risk_level', 'user', 'asset', 'system_user'} + + match = {} + exact = {} + + for k, v in kwargs.items(): + if k in exact_fields: + exact[k] = v + elif k in match_fields: + match[k] = v + + # 处理时间 + timestamp__gte = kwargs.get('timestamp__gte') + timestamp__lte = kwargs.get('timestamp__lte') + timestamp_range = {} + + if timestamp__gte: + timestamp_range['gte'] = timestamp__gte + if timestamp__lte: + timestamp_range['lte'] = timestamp__lte + + # 处理组织 + must_not = [] + org_id = match.get('org_id') + if org_id == '': + match.pop('org_id') + must_not.append({'wildcard': {'org_id': '*'}}) + + # 构建 body + body = { + 'query': { + 'bool': { + 'must': [ + {'match': {k: v}} for k, v in match.items() + ], + 'must_not': must_not, + 'filter': [ + { + 'term': {k: v} + } for k, v in exact.items() + ] + [ + { + 'range': { + 'timestamp': timestamp_range + } + } + ] + } + }, + } + return body + + +class QuerySet(DJQuerySet): + _method_calls = None + _storage = None + _command_store_config = None + _slice = None # (from_, size) + default_days_ago = 5 + max_result_window = 10000 + + def __init__(self, command_store_config): + self._method_calls = [] + self._command_store_config = command_store_config + self._storage = CommandStore(command_store_config) + + @lazyproperty + def _grouped_method_calls(self): + _method_calls = {k: list(v) for k, v in groupby(self._method_calls, lambda x: x[0])} + return _method_calls + + @lazyproperty + def _filter_kwargs(self): + _method_calls = self._grouped_method_calls + filter_calls = _method_calls.get('filter') + if not filter_calls: + return {} + names, multi_args, multi_kwargs = zip(*filter_calls) + kwargs = reduce(lambda x, y: {**x, **y}, multi_kwargs, {}) + + striped_kwargs = {} + for k, v in kwargs.items(): + k = k.replace('__exact', '') + k = k.replace('__startswith', '') + k = k.replace('__icontains', '') + striped_kwargs[k] = v + return striped_kwargs + + @lazyproperty + def _sort(self): + order_by = self._grouped_method_calls.get('order_by') + if order_by: + for call in reversed(order_by): + fields = call[1] + if fields: + field = fields[-1] + + if field.startswith('-'): + direction = 'desc' + else: + direction = 'asc' + field = field.lstrip('-+') + sort = f'{field}:{direction}' + return sort + + def __execute(self): + _filter_kwargs = self._filter_kwargs + _sort = self._sort + from_, size = self._slice or (None, None) + data = self._storage.filter(_filter_kwargs, from_=from_, size=size, sort=_sort) + return data + + def __stage_method_call(self, item, *args, **kwargs): + _clone = self.__clone() + _clone._method_calls.append((item, args, kwargs)) + return _clone + + def __clone(self): + uqs = QuerySet(self._command_store_config) + uqs._method_calls = self._method_calls.copy() + uqs._slice = self._slice + return uqs + + def count(self, limit_to_max_result_window=True): + filter_kwargs = self._filter_kwargs + count = self._storage.count(**filter_kwargs) + if limit_to_max_result_window: + count = min(count, self.max_result_window) + return count + + def __getattribute__(self, item): + if any(( + item.startswith('__'), + item in QuerySet.__dict__, + )): + return object.__getattribute__(self, item) + + origin_attr = object.__getattribute__(self, item) + if not inspect.ismethod(origin_attr): + return origin_attr + + attr = partial(self.__stage_method_call, item) + return attr + + def __getitem__(self, item): + max_window = self.max_result_window + if isinstance(item, slice): + if self._slice is None: + clone = self.__clone() + from_ = item.start or 0 + if item.stop is None: + size = 10 + else: + size = item.stop - from_ + + if from_ + size > max_window: + if from_ >= max_window: + from_ = max_window + size = 0 + else: + size = max_window - from_ + clone._slice = (from_, size) + return clone + return self.__execute()[item] + + def __repr__(self): + return self.__execute().__repr__() + + def __iter__(self): + return iter(self.__execute()) + + def __len__(self): + return self.count() diff --git a/apps/terminal/filters.py b/apps/terminal/filters.py new file mode 100644 index 000000000..caed19a9c --- /dev/null +++ b/apps/terminal/filters.py @@ -0,0 +1,82 @@ +from django_filters import rest_framework as filters +from django.db.models import QuerySet + +from orgs.utils import current_org +from terminal.models import Command, CommandStorage + + +class CommandFilter(filters.FilterSet): + date_from = filters.DateTimeFilter(method='do_nothing') + date_to = filters.DateTimeFilter(method='do_nothing') + session_id = filters.CharFilter(field_name='session') + command_storage_id = filters.UUIDFilter(method='do_nothing') + user = filters.CharFilter(lookup_expr='startswith') + input = filters.CharFilter(lookup_expr='icontains') + + class Meta: + model = Command + fields = [ + 'asset', 'system_user', 'user', 'session', 'risk_level', 'input', + 'date_from', 'date_to', 'session_id', 'risk_level', 'command_storage_id', + ] + + def do_nothing(self, queryset, name, value): + return queryset + + @property + def qs(self): + qs = super().qs + qs = qs.filter(org_id=self.get_org_id()) + qs = self.filter_by_timestamp(qs) + return qs + + def filter_by_timestamp(self, qs: QuerySet): + date_from = self.form.cleaned_data.get('date_from') + date_to = self.form.cleaned_data.get('date_to') + + filters = {} + if date_from: + date_from = date_from.timestamp() + filters['timestamp__gte'] = date_from + + if date_to: + date_to = date_to.timestamp() + filters['timestamp__lte'] = date_to + + qs = qs.filter(**filters) + return qs + + @staticmethod + def get_org_id(): + if current_org.is_default(): + org_id = '' + else: + org_id = current_org.id + return org_id + + +class CommandFilterForStorageTree(CommandFilter): + asset = filters.CharFilter(method='do_nothing') + system_user = filters.CharFilter(method='do_nothing') + session = filters.CharFilter(method='do_nothing') + risk_level = filters.NumberFilter(method='do_nothing') + + class Meta: + model = CommandStorage + fields = [ + 'asset', 'system_user', 'user', 'session', 'risk_level', 'input', + 'date_from', 'date_to', 'session_id', 'risk_level', 'command_storage_id', + ] + + +class CommandStorageFilter(filters.FilterSet): + real = filters.BooleanFilter(method='filter_real') + + class Meta: + model = CommandStorage + fields = ['real', 'name', 'type'] + + def filter_real(self, queryset, name, value): + if value: + queryset = queryset.exclude(name='null') + return queryset diff --git a/apps/terminal/models/storage.py b/apps/terminal/models/storage.py index b74feae40..3f574985c 100644 --- a/apps/terminal/models/storage.py +++ b/apps/terminal/models/storage.py @@ -1,16 +1,24 @@ from __future__ import unicode_literals import os +from importlib import import_module + import jms_storage from django.db import models from django.utils.translation import ugettext_lazy as _ from django.conf import settings from common.mixins import CommonModelMixin +from common.utils import get_logger from common.fields.model import EncryptJsonDictTextField +from terminal.backends import TYPE_ENGINE_MAPPING from .terminal import Terminal +from .command import Command from .. import const +logger = get_logger(__file__) + + class CommandStorage(CommonModelMixin): name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True) type = models.CharField( @@ -50,6 +58,18 @@ class CommandStorage(CommonModelMixin): def is_use(self): return Terminal.objects.filter(command_storage=self.name).exists() + def get_command_queryset(self): + if self.type_server: + qs = Command.objects.all() + else: + if self.type not in TYPE_ENGINE_MAPPING: + logger.error(f'Command storage `{self.type}` not support') + return Command.objects.none() + engine_mod = import_module(TYPE_ENGINE_MAPPING[self.type]) + qs = engine_mod.QuerySet(self.config) + qs.model = Command + return qs + class ReplayStorage(CommonModelMixin): name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)