# -*- coding: utf-8 -*- # import datetime import inspect from collections.abc import Iterable from functools import reduce, partial from itertools import groupby from uuid import UUID from django.utils.translation import gettext_lazy as _ from django.db.models import QuerySet as DJQuerySet from elasticsearch import Elasticsearch from elasticsearch.helpers import bulk from elasticsearch.exceptions import RequestError, NotFoundError from common.utils.common import lazyproperty from common.utils import get_logger from common.utils.timezone import local_now_date_display from common.exceptions import JMSException logger = get_logger(__file__) class InvalidElasticsearch(JMSException): default_code = 'invalid_elasticsearch' default_detail = _('Invalid elasticsearch config') class NotSupportElasticsearch8(JMSException): default_code = 'not_support_elasticsearch8' default_detail = _('Not Support Elasticsearch8') class ES(object): def __init__(self, config, properties, keyword_fields, exact_fields=None, match_fields=None): self.config = config hosts = self.config.get('HOSTS') kwargs = self.config.get('OTHER', {}) ignore_verify_certs = kwargs.pop('IGNORE_VERIFY_CERTS', False) if ignore_verify_certs: kwargs['verify_certs'] = None self.es = Elasticsearch(hosts=hosts, max_retries=0, **kwargs) self.index_prefix = self.config.get('INDEX') or 'jumpserver' self.is_index_by_date = bool(self.config.get('INDEX_BY_DATE', False)) self.index = None self.query_index = None self.properties = properties self.exact_fields, self.match_fields, self.keyword_fields = set(), set(), set() if isinstance(keyword_fields, Iterable): self.keyword_fields.update(keyword_fields) if isinstance(exact_fields, Iterable): self.exact_fields.update(exact_fields) if isinstance(match_fields, Iterable): self.match_fields.update(match_fields) self.init_index() self.doc_type = self.config.get("DOC_TYPE") or '_doc' if self.is_new_index_type(): self.doc_type = '_doc' self.exact_fields.update(self.keyword_fields) else: self.match_fields.update(self.keyword_fields) def init_index(self): if self.is_index_by_date: date = local_now_date_display() self.index = '%s-%s' % (self.index_prefix, date) self.query_index = '%s-alias' % self.index_prefix else: self.index = self.config.get("INDEX") or 'jumpserver' self.query_index = self.config.get("INDEX") or 'jumpserver' def is_new_index_type(self): if not self.ping(timeout=2): return False info = self.es.info() version = info['version']['number'].split('.')[0] if version == '8': raise NotSupportElasticsearch8 try: # 获取索引信息,如果没有定义,直接返回 data = self.es.indices.get_mapping(self.index) except NotFoundError: return False try: if version == '6': # 检测索引是不是新的类型 es6 properties = data[self.index]['mappings']['data']['properties'] else: # 检测索引是不是新的类型 es7 default index type: _doc properties = data[self.index]['mappings']['properties'] for keyword in self.keyword_fields: if not properties[keyword]['type'] == 'keyword': break else: return True except KeyError: return False def pre_use_check(self): if not self.ping(timeout=3): raise InvalidElasticsearch self._ensure_index_exists() def _ensure_index_exists(self): info = self.es.info() version = info['version']['number'].split('.')[0] if version == '6': mappings = {'mappings': {'data': {'properties': self.properties}}} else: mappings = {'mappings': {'properties': self.properties}} if self.is_index_by_date: mappings['aliases'] = { self.query_index: {} } try: self.es.indices.create(self.index, body=mappings) return except RequestError as e: if e.error == 'resource_already_exists_exception': logger.warning(e) else: logger.exception(e) def make_data(self, data): return [] def save(self, **kwargs): data = self.make_data(kwargs) return self.es.index(index=self.index, doc_type=self.doc_type, body=data) def bulk_save(self, command_set, raise_on_error=True): actions = [] for command in command_set: data = dict( _index=self.index, _type=self.doc_type, _source=self.make_data(command), ) actions.append(data) return bulk(self.es, actions, index=self.index, raise_on_error=raise_on_error) def get(self, query: dict): item = None data = self.filter(query, size=1) if len(data) >= 1: item = data[0] return item def filter(self, query: dict, from_=None, size=None, sort=None): try: data = self._filter(query, from_, size, sort) except Exception as e: logger.error('ES filter error: {}'.format(e)) data = [] return data def _filter(self, query: dict, from_=None, size=None, sort=None): body = self.get_query_body(**query) data = self.es.search( index=self.query_index, doc_type=self.doc_type, body=body, from_=from_, size=size, sort=sort ) source_data = [] for item in data['hits']['hits']: if item: item['_source'].update({'es_id': item['_id']}) source_data.append(item['_source']) return source_data def count(self, **query): try: body = self.get_query_body(**query) data = self.es.count(index=self.query_index, doc_type=self.doc_type, body=body) count = data["count"] except Exception as e: logger.error('ES count error: {}'.format(e)) count = 0 return count def __getattr__(self, item): return getattr(self.es, item) def all(self): """返回所有数据""" raise NotImplementedError("Not support") def ping(self, timeout=None): try: return self.es.ping(request_timeout=timeout) except Exception: return False @staticmethod def handler_time_field(data): datetime__gte = data.get('datetime__gte') datetime__lte = data.get('datetime__lte') datetime_range = {} if datetime__gte: if isinstance(datetime__gte, datetime.datetime): datetime__gte = datetime__gte.strftime('%Y-%m-%d %H:%M:%S') datetime_range['gte'] = datetime__gte if datetime__lte: if isinstance(datetime__lte, datetime.datetime): datetime__lte = datetime__lte.strftime('%Y-%m-%d %H:%M:%S') datetime_range['lte'] = datetime__lte return 'datetime', datetime_range def get_query_body(self, **kwargs): new_kwargs = {} for k, v in kwargs.items(): if isinstance(v, UUID): v = str(v) if k == 'pk': k = 'id' new_kwargs[k] = v kwargs = new_kwargs index_in_field = 'id__in' exact_fields = self.exact_fields match_fields = self.match_fields match = {} exact = {} index = {} if index_in_field in kwargs: index['values'] = kwargs[index_in_field] for k, v in kwargs.items(): if k in exact_fields: exact[k] = v elif k in match_fields: match[k] = v # 处理时间 time_field_name, time_range = self.handler_time_field(kwargs) # 处理组织 should = [] org_id = match.get('org_id') real_default_org_id = '00000000-0000-0000-0000-000000000002' root_org_id = '00000000-0000-0000-0000-000000000000' if org_id == root_org_id: match.pop('org_id') elif org_id in (real_default_org_id, ''): match.pop('org_id') should.append({ 'bool': { 'must_not': [ { 'wildcard': {'org_id': '*'} } ]} }) should.append({'match': {'org_id': real_default_org_id}}) # 构建 body body = { 'query': { 'bool': { 'must': [ {'match': {k: v}} for k, v in match.items() ], 'should': should, 'filter': [ { 'term': {k: v} } for k, v in exact.items() ] + [ { 'range': { time_field_name: time_range } } ] + [ { 'ids': {k: v} } for k, v in index.items() ] } }, } return body class QuerySet(DJQuerySet): default_days_ago = 7 max_result_window = 10000 def __init__(self, es_instance): self._method_calls = [] self._slice = None # (from_, size) self._storage = es_instance # 命令列表模糊搜索时报错 super().__init__() @lazyproperty def _grouped_method_calls(self): _method_calls = {k: list(v) for k, v in groupby(self._method_calls, lambda x: x[0])} return _method_calls @lazyproperty def _filter_kwargs(self): _method_calls = self._grouped_method_calls filter_calls = _method_calls.get('filter') if not filter_calls: return {} names, multi_args, multi_kwargs = zip(*filter_calls) kwargs = reduce(lambda x, y: {**x, **y}, multi_kwargs, {}) striped_kwargs = {} for k, v in kwargs.items(): k = k.replace('__exact', '') k = k.replace('__startswith', '') k = k.replace('__icontains', '') striped_kwargs[k] = v return striped_kwargs @lazyproperty def _sort(self): order_by = self._grouped_method_calls.get('order_by') if order_by: for call in reversed(order_by): fields = call[1] if fields: field = fields[-1] if field.startswith('-'): direction = 'desc' else: direction = 'asc' field = field.lstrip('-+') sort = f'{field}:{direction}' return sort def __execute(self): _filter_kwargs = self._filter_kwargs _sort = self._sort from_, size = self._slice or (None, None) data = self._storage.filter(_filter_kwargs, from_=from_, size=size, sort=_sort) return self.model.from_multi_dict(data) def __stage_method_call(self, item, *args, **kwargs): _clone = self.__clone() _clone._method_calls.append((item, args, kwargs)) return _clone def __clone(self): uqs = QuerySet(self._storage) uqs._method_calls = self._method_calls.copy() uqs._slice = self._slice uqs.model = self.model return uqs def get(self, **kwargs): kwargs.update(self._filter_kwargs) return self._storage.get(kwargs) def count(self, limit_to_max_result_window=True): filter_kwargs = self._filter_kwargs count = self._storage.count(**filter_kwargs) if limit_to_max_result_window: count = min(count, self.max_result_window) return count def __getattribute__(self, item): if any(( item.startswith('__'), item in QuerySet.__dict__, )): return object.__getattribute__(self, item) origin_attr = object.__getattribute__(self, item) if not inspect.ismethod(origin_attr): return origin_attr attr = partial(self.__stage_method_call, item) return attr def __getitem__(self, item): max_window = self.max_result_window if isinstance(item, slice): if self._slice is None: clone = self.__clone() from_ = item.start or 0 if item.stop is None: size = self.max_result_window - from_ else: size = item.stop - from_ if from_ + size > max_window: if from_ >= max_window: from_ = max_window size = 0 else: size = max_window - from_ clone._slice = (from_, size) return clone return self.__execute()[item] def __repr__(self): return self.__execute().__repr__() def __iter__(self): return iter(self.__execute()) def __len__(self): return self.count()