perf: 优化命令记录慢的问题

pull/5698/head^2
xinwen 4 years ago
parent 7f42e59714
commit 3e7e01418d

@ -19,7 +19,10 @@ def extract_object_name(exc, index=0):
`No User matches the given query.`
提取 `User``index=1`
"""
(msg, *_) = exc.args
if exc.args:
(msg, *others) = exc.args
else:
return gettext('Object')
return gettext(msg.split(sep=' ', maxsplit=index + 1)[index])

@ -97,6 +97,8 @@ class SimpleMetadataWithFilters(SimpleMetadata):
fields = view.filterset_fields
elif hasattr(view, 'get_filterset_fields'):
fields = view.get_filterset_fields(request)
elif hasattr(view, 'filterset_class'):
fields = view.filterset_class.Meta.fields
if isinstance(fields, dict):
fields = list(fields.keys())

@ -7,7 +7,7 @@ from common.utils import get_logger
from perms.pagination import NodeGrantedAssetPagination, AllGrantedAssetPagination
from assets.models import Asset, Node
from perms import serializers
from perms.utils.asset.user_permission import UserGrantedAssetsQueryUtils, QuerySetStage
from perms.utils.asset.user_permission import UserGrantedAssetsQueryUtils
logger = get_logger(__name__)

@ -53,122 +53,6 @@ def get_user_all_asset_perm_ids(user) -> set:
return asset_perm_ids
class QuerySetStage:
def __init__(self):
self._prefetch_related = set()
self._only = ()
self._filters = []
self._querysets_and = []
self._querysets_or = []
self._order_by = None
self._annotate = []
self._before_union_merge_funs = set()
self._after_union_merge_funs = set()
def annotate(self, *args, **kwargs):
self._annotate.append((args, kwargs))
self._before_union_merge_funs.add(self._merge_annotate)
return self
def prefetch_related(self, *lookups):
self._prefetch_related.update(lookups)
self._before_union_merge_funs.add(self._merge_prefetch_related)
return self
def only(self, *fields):
self._only = fields
self._before_union_merge_funs.add(self._merge_only)
return self
def order_by(self, *field_names):
self._order_by = field_names
self._after_union_merge_funs.add(self._merge_order_by)
return self
def filter(self, *args, **kwargs):
self._filters.append((args, kwargs))
self._before_union_merge_funs.add(self._merge_filters)
return self
def and_with_queryset(self, qs: QuerySet):
assert isinstance(qs, QuerySet), f'Must be `QuerySet`'
self._order_by = qs.query.order_by
self._after_union_merge_funs.add(self._merge_order_by)
self._querysets_and.append(qs.order_by())
self._before_union_merge_funs.add(self._merge_querysets_and)
return self
def or_with_queryset(self, qs: QuerySet):
assert isinstance(qs, QuerySet), f'Must be `QuerySet`'
self._order_by = qs.query.order_by
self._after_union_merge_funs.add(self._merge_order_by)
self._querysets_or.append(qs.order_by())
self._before_union_merge_funs.add(self._merge_querysets_or)
return self
def merge_multi_before_union(self, *querysets):
ret = []
for qs in querysets:
qs = self.merge_before_union(qs)
ret.append(qs)
return ret
def _merge_only(self, qs: QuerySet):
if self._only:
qs = qs.only(*self._only)
return qs
def _merge_filters(self, qs: QuerySet):
if self._filters:
for args, kwargs in self._filters:
qs = qs.filter(*args, **kwargs)
return qs
def _merge_querysets_and(self, qs: QuerySet):
if self._querysets_and:
for qs_and in self._querysets_and:
qs &= qs_and
return qs
def _merge_annotate(self, qs: QuerySet):
if self._annotate:
for args, kwargs in self._annotate:
qs = qs.annotate(*args, **kwargs)
return qs
def _merge_querysets_or(self, qs: QuerySet):
if self._querysets_or:
for qs_or in self._querysets_or:
qs |= qs_or
return qs
def _merge_prefetch_related(self, qs: QuerySet):
if self._prefetch_related:
qs = qs.prefetch_related(*self._prefetch_related)
return qs
def _merge_order_by(self, qs: QuerySet):
if self._order_by is not None:
qs = qs.order_by(*self._order_by)
return qs
def merge_before_union(self, qs: QuerySet) -> QuerySet:
assert isinstance(qs, QuerySet), f'Must be `QuerySet`'
for fun in self._before_union_merge_funs:
qs = fun(qs)
return qs
def merge_after_union(self, qs: QuerySet) -> QuerySet:
for fun in self._after_union_merge_funs:
qs = fun(qs)
return qs
def merge(self, qs: QuerySet) -> QuerySet:
qs = self.merge_before_union(qs)
qs = self.merge_after_union(qs)
return qs
class UserGrantedTreeRefreshController:
key_template = 'perms.user.node_tree.builded_orgs.user_id:{user_id}'

@ -8,11 +8,14 @@ from rest_framework import viewsets
from rest_framework import generics
from rest_framework.fields import DateTimeField
from rest_framework.response import Response
from rest_framework import status
from rest_framework.decorators import action
from django.template import loader
from terminal.models import CommandStorage
from terminal.filters import CommandFilter
from orgs.utils import current_org
from common.permissions import IsOrgAdminOrAppUser, IsOrgAuditor, IsAppUser
from common.const.http import GET
from common.utils import get_logger
from terminal.utils import send_command_alert_mail
from terminal.serializers import InsecureCommandAlertSerializer
@ -89,7 +92,7 @@ class CommandQueryMixin:
return date_from_st, date_to_st
class CommandViewSet(CommandQueryMixin, viewsets.ModelViewSet):
class CommandViewSet(viewsets.ModelViewSet):
"""接受app发送来的command log, 格式如下
{
"user": "admin",
@ -103,7 +106,16 @@ class CommandViewSet(CommandQueryMixin, viewsets.ModelViewSet):
"""
command_store = get_command_storage()
permission_classes = [IsOrgAdminOrAppUser | IsOrgAuditor]
serializer_class = SessionCommandSerializer
filterset_class = CommandFilter
ordering_fields = ('timestamp', )
def get_queryset(self):
command_storage_id = self.request.query_params.get('command_storage_id')
storage = CommandStorage.objects.get(id=command_storage_id)
qs = storage.get_command_queryset()
return qs
def create(self, request, *args, **kwargs):
serializer = self.serializer_class(data=request.data, many=True)

@ -3,9 +3,15 @@
from rest_framework import viewsets, generics, status
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.decorators import action
from django.utils.translation import ugettext_lazy as _
from django_filters import utils
from terminal import const
from common.const.http import GET
from common.permissions import IsSuperUser
from terminal.filters import CommandStorageFilter, CommandFilter, CommandFilterForStorageTree
from ..models import CommandStorage, ReplayStorage
from ..serializers import CommandStorageSerializer, ReplayStorageSerializer
@ -30,11 +36,52 @@ class BaseStorageViewSetMixin:
class CommandStorageViewSet(BaseStorageViewSetMixin, viewsets.ModelViewSet):
filterset_fields = ('name', 'type',)
search_fields = filterset_fields
search_fields = ('name', 'type',)
queryset = CommandStorage.objects.all()
serializer_class = CommandStorageSerializer
permission_classes = (IsSuperUser,)
filterset_class = CommandStorageFilter
@action(methods=[GET], detail=False, filterset_class=CommandFilterForStorageTree)
def tree(self, request: Request):
storage_qs = self.get_queryset().exclude(name='null')
storages_with_count = []
for storage in storage_qs:
command_qs = storage.get_command_queryset()
filterset = CommandFilter(
data=request.query_params, queryset=command_qs,
request=request
)
if not filterset.is_valid():
raise utils.translate_validation(filterset.errors)
command_qs = filterset.qs
if storage.type == const.CommandStorageTypeChoices.es:
command_count = command_qs.count(limit_to_max_result_window=False)
else:
command_count = command_qs.count()
storages_with_count.append((storage, command_count))
root = {
'id': 'root',
'name': _('Command storages'),
'title': _('Command storages'),
'pId': '',
'isParent': True,
'open': True,
}
nodes = [
{
'id': storage.id,
'name': f'{storage.name}({storage.type})({command_count})',
'title': f'{storage.name}({storage.type})',
'pId': 'root',
'isParent': False,
'open': False,
} for storage, command_count in storages_with_count
]
nodes.append(root)
return Response(data=nodes)
class ReplayStorageViewSet(BaseStorageViewSetMixin, viewsets.ModelViewSet):

@ -1,53 +1,270 @@
# -*- coding: utf-8 -*-
#
from datetime import datetime
from jms_storage.es import ESStorage
from functools import reduce, partial
from itertools import groupby
import pytz
from uuid import UUID
import inspect
from django.db.models import QuerySet as DJQuerySet
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from common.utils.common import lazyproperty
from common.utils import get_logger
from .base import CommandBase
from .models import AbstractSessionCommand
logger = get_logger(__file__)
class CommandStore(ESStorage, CommandBase):
def __init__(self, params):
super().__init__(params)
class CommandStore():
def __init__(self, config):
hosts = config.get("HOSTS")
kwargs = config.get("OTHER", {})
self.index = config.get("INDEX") or 'jumpserver'
self.doc_type = config.get("DOC_TYPE") or 'command_store'
self.es = Elasticsearch(hosts=hosts, **kwargs)
def filter(self, date_from=None, date_to=None,
user=None, asset=None, system_user=None,
input=None, session=None, risk_level=None, org_id=None):
@staticmethod
def make_data(command):
data = dict(
user=command["user"], asset=command["asset"],
system_user=command["system_user"], input=command["input"],
output=command["output"], risk_level=command["risk_level"],
session=command["session"], timestamp=command["timestamp"],
org_id=command["org_id"]
)
data["date"] = datetime.fromtimestamp(command['timestamp'], tz=pytz.UTC)
return data
if date_from is not None:
if isinstance(date_from, float):
date_from = datetime.fromtimestamp(date_from)
if date_to is not None:
if isinstance(date_to, float):
date_to = datetime.fromtimestamp(date_to)
try:
data = super().filter(date_from=date_from, date_to=date_to,
user=user, asset=asset, system_user=system_user,
input=input, session=session,
risk_level=risk_level, org_id=org_id)
except Exception as e:
logger.error(e, exc_info=True)
return []
else:
return AbstractSessionCommand.from_multi_dict(
[item["_source"] for item in data["hits"] if item]
def bulk_save(self, command_set, raise_on_error=True):
actions = []
for command in command_set:
data = dict(
_index=self.index,
_type=self.doc_type,
_source=self.make_data(command),
)
actions.append(data)
return bulk(self.es, actions, index=self.index, raise_on_error=raise_on_error)
def save(self, command):
"""
保存命令到数据库
"""
data = self.make_data(command)
return self.es.index(index=self.index, doc_type=self.doc_type, body=data)
def filter(self, query: dict, from_=None, size=None, sort=None):
body = self.get_query_body(**query)
data = self.es.search(
index=self.index, doc_type=self.doc_type, body=body, from_=from_, size=size,
sort=sort
)
return AbstractSessionCommand.from_multi_dict(
[item['_source'] for item in data['hits']['hits'] if item]
)
def count(self, **query):
body = self.get_query_body(**query)
data = self.es.count(index=self.index, doc_type=self.doc_type, body=body)
return data["count"]
def __getattr__(self, item):
return getattr(self.es, item)
def count(self, date_from=None, date_to=None, user=None, asset=None,
system_user=None, input=None, session=None):
def all(self):
"""返回所有数据"""
raise NotImplementedError("Not support")
def ping(self):
try:
count = super().count(
date_from=date_from, date_to=date_to, user=user, asset=asset,
system_user=system_user, input=input, session=session
)
except Exception as e:
logger.error(e, exc_info=True)
return 0
else:
return count
return self.es.ping()
except Exception:
return False
@staticmethod
def get_query_body(**kwargs):
new_kwargs = {}
for k, v in kwargs.items():
new_kwargs[k] = str(v) if isinstance(v, UUID) else v
kwargs = new_kwargs
exact_fields = {}
match_fields = {'session', 'input', 'org_id', 'risk_level', 'user', 'asset', 'system_user'}
match = {}
exact = {}
for k, v in kwargs.items():
if k in exact_fields:
exact[k] = v
elif k in match_fields:
match[k] = v
# 处理时间
timestamp__gte = kwargs.get('timestamp__gte')
timestamp__lte = kwargs.get('timestamp__lte')
timestamp_range = {}
if timestamp__gte:
timestamp_range['gte'] = timestamp__gte
if timestamp__lte:
timestamp_range['lte'] = timestamp__lte
# 处理组织
must_not = []
org_id = match.get('org_id')
if org_id == '':
match.pop('org_id')
must_not.append({'wildcard': {'org_id': '*'}})
# 构建 body
body = {
'query': {
'bool': {
'must': [
{'match': {k: v}} for k, v in match.items()
],
'must_not': must_not,
'filter': [
{
'term': {k: v}
} for k, v in exact.items()
] + [
{
'range': {
'timestamp': timestamp_range
}
}
]
}
},
}
return body
class QuerySet(DJQuerySet):
_method_calls = None
_storage = None
_command_store_config = None
_slice = None # (from_, size)
default_days_ago = 5
max_result_window = 10000
def __init__(self, command_store_config):
self._method_calls = []
self._command_store_config = command_store_config
self._storage = CommandStore(command_store_config)
@lazyproperty
def _grouped_method_calls(self):
_method_calls = {k: list(v) for k, v in groupby(self._method_calls, lambda x: x[0])}
return _method_calls
@lazyproperty
def _filter_kwargs(self):
_method_calls = self._grouped_method_calls
filter_calls = _method_calls.get('filter')
if not filter_calls:
return {}
names, multi_args, multi_kwargs = zip(*filter_calls)
kwargs = reduce(lambda x, y: {**x, **y}, multi_kwargs, {})
striped_kwargs = {}
for k, v in kwargs.items():
k = k.replace('__exact', '')
k = k.replace('__startswith', '')
k = k.replace('__icontains', '')
striped_kwargs[k] = v
return striped_kwargs
@lazyproperty
def _sort(self):
order_by = self._grouped_method_calls.get('order_by')
if order_by:
for call in reversed(order_by):
fields = call[1]
if fields:
field = fields[-1]
if field.startswith('-'):
direction = 'desc'
else:
direction = 'asc'
field = field.lstrip('-+')
sort = f'{field}:{direction}'
return sort
def __execute(self):
_filter_kwargs = self._filter_kwargs
_sort = self._sort
from_, size = self._slice or (None, None)
data = self._storage.filter(_filter_kwargs, from_=from_, size=size, sort=_sort)
return data
def __stage_method_call(self, item, *args, **kwargs):
_clone = self.__clone()
_clone._method_calls.append((item, args, kwargs))
return _clone
def __clone(self):
uqs = QuerySet(self._command_store_config)
uqs._method_calls = self._method_calls.copy()
uqs._slice = self._slice
return uqs
def count(self, limit_to_max_result_window=True):
filter_kwargs = self._filter_kwargs
count = self._storage.count(**filter_kwargs)
if limit_to_max_result_window:
count = min(count, self.max_result_window)
return count
def __getattribute__(self, item):
if any((
item.startswith('__'),
item in QuerySet.__dict__,
)):
return object.__getattribute__(self, item)
origin_attr = object.__getattribute__(self, item)
if not inspect.ismethod(origin_attr):
return origin_attr
attr = partial(self.__stage_method_call, item)
return attr
def __getitem__(self, item):
max_window = self.max_result_window
if isinstance(item, slice):
if self._slice is None:
clone = self.__clone()
from_ = item.start or 0
if item.stop is None:
size = 10
else:
size = item.stop - from_
if from_ + size > max_window:
if from_ >= max_window:
from_ = max_window
size = 0
else:
size = max_window - from_
clone._slice = (from_, size)
return clone
return self.__execute()[item]
def __repr__(self):
return self.__execute().__repr__()
def __iter__(self):
return iter(self.__execute())
def __len__(self):
return self.count()

@ -0,0 +1,82 @@
from django_filters import rest_framework as filters
from django.db.models import QuerySet
from orgs.utils import current_org
from terminal.models import Command, CommandStorage
class CommandFilter(filters.FilterSet):
date_from = filters.DateTimeFilter(method='do_nothing')
date_to = filters.DateTimeFilter(method='do_nothing')
session_id = filters.CharFilter(field_name='session')
command_storage_id = filters.UUIDFilter(method='do_nothing')
user = filters.CharFilter(lookup_expr='startswith')
input = filters.CharFilter(lookup_expr='icontains')
class Meta:
model = Command
fields = [
'asset', 'system_user', 'user', 'session', 'risk_level', 'input',
'date_from', 'date_to', 'session_id', 'risk_level', 'command_storage_id',
]
def do_nothing(self, queryset, name, value):
return queryset
@property
def qs(self):
qs = super().qs
qs = qs.filter(org_id=self.get_org_id())
qs = self.filter_by_timestamp(qs)
return qs
def filter_by_timestamp(self, qs: QuerySet):
date_from = self.form.cleaned_data.get('date_from')
date_to = self.form.cleaned_data.get('date_to')
filters = {}
if date_from:
date_from = date_from.timestamp()
filters['timestamp__gte'] = date_from
if date_to:
date_to = date_to.timestamp()
filters['timestamp__lte'] = date_to
qs = qs.filter(**filters)
return qs
@staticmethod
def get_org_id():
if current_org.is_default():
org_id = ''
else:
org_id = current_org.id
return org_id
class CommandFilterForStorageTree(CommandFilter):
asset = filters.CharFilter(method='do_nothing')
system_user = filters.CharFilter(method='do_nothing')
session = filters.CharFilter(method='do_nothing')
risk_level = filters.NumberFilter(method='do_nothing')
class Meta:
model = CommandStorage
fields = [
'asset', 'system_user', 'user', 'session', 'risk_level', 'input',
'date_from', 'date_to', 'session_id', 'risk_level', 'command_storage_id',
]
class CommandStorageFilter(filters.FilterSet):
real = filters.BooleanFilter(method='filter_real')
class Meta:
model = CommandStorage
fields = ['real', 'name', 'type']
def filter_real(self, queryset, name, value):
if value:
queryset = queryset.exclude(name='null')
return queryset

@ -1,16 +1,24 @@
from __future__ import unicode_literals
import os
from importlib import import_module
import jms_storage
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from common.mixins import CommonModelMixin
from common.utils import get_logger
from common.fields.model import EncryptJsonDictTextField
from terminal.backends import TYPE_ENGINE_MAPPING
from .terminal import Terminal
from .command import Command
from .. import const
logger = get_logger(__file__)
class CommandStorage(CommonModelMixin):
name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)
type = models.CharField(
@ -50,6 +58,18 @@ class CommandStorage(CommonModelMixin):
def is_use(self):
return Terminal.objects.filter(command_storage=self.name).exists()
def get_command_queryset(self):
if self.type_server:
qs = Command.objects.all()
else:
if self.type not in TYPE_ENGINE_MAPPING:
logger.error(f'Command storage `{self.type}` not support')
return Command.objects.none()
engine_mod = import_module(TYPE_ENGINE_MAPPING[self.type])
qs = engine_mod.QuerySet(self.config)
qs.model = Command
return qs
class ReplayStorage(CommonModelMixin):
name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)

Loading…
Cancel
Save