# -*- coding: utf-8 -*-
#
from django.conf import settings
from django.utils import timezone
from rest_framework import generics
from rest_framework.fields import DateTimeField
from rest_framework.response import Response

from terminal.models import CommandStorage, Session, Command
from terminal.filters import CommandFilter
from orgs.utils import current_org
from common.api import JMSBulkModelViewSet
from common.utils import get_logger
from terminal.backends.command.serializers import InsecureCommandAlertSerializer
from terminal.exceptions import StorageInvalid
from terminal.backends import (
    get_command_storage, get_multi_command_storage,
    SessionCommandSerializer,
)
from terminal.notifications import CommandAlertMessage

logger = get_logger(__name__)
__all__ = ['CommandViewSet', 'InsecureCommandAlertAPI']


class CommandQueryMixin:
    """Mixin that builds a command queryset from request query parameters.

    The host view is expected to provide ``self.request``; the query is
    delegated to the object returned by ``get_multi_command_storage()``.
    """

    command_store = get_command_storage()
    filterset_fields = [
        "asset", "account", "user", "session",
        "risk_level", "input"
    ]
    # Look-back window (in days) used when the request gives no ``date_from``.
    default_days_ago = 5

    @staticmethod
    def get_org_id():
        """Return ``''`` for the default organization, else the current org id."""
        return '' if current_org.is_default() else current_org.id

    def get_query_risk_level(self):
        """Return the ``risk_level`` query parameter as an int.

        Returns ``None`` when the parameter is absent or is not made up of
        digits only.
        """
        raw_level = self.request.query_params.get('risk_level')
        if raw_level is not None and raw_level.isdigit():
            return int(raw_level)
        return None

    def get_queryset(self):
        """Return the commands matching the request's filters and date range."""
        # Fix for accessing /docs/ (schema generation marks the view as fake).
        if hasattr(self, 'swagger_fake_view'):
            return self.command_store.model.objects.none()

        date_from, date_to = self.get_date_range()
        params = self.request.query_params
        storage = get_multi_command_storage()
        filters = {
            'date_from': date_from,
            'date_to': date_to,
            'user': params.get("user"),
            'asset': params.get("asset"),
            'account': params.get("account"),
            'input': params.get("input"),
            # ``session_id`` takes precedence; ``session`` is the fallback.
            'session': params.get("session_id", params.get('session')),
            'risk_level': self.get_query_risk_level(),
            'org_id': self.get_org_id(),
        }
        return storage.filter(**filters)

    def filter_queryset(self, queryset):
        """Return ``queryset`` untouched.

        Works around an exception raised when the parent class filters by
        filter_fields while commands are stored in ES; the returned queryset
        is a list in that case.
        """
        return queryset

    def get_date_range(self):
        """Return ``(date_from, date_to)`` as POSIX timestamps.

        Defaults to the last ``default_days_ago`` days ending now; either
        bound is replaced by the parsed ``date_from`` / ``date_to`` query
        parameter when that parameter is non-empty.
        """
        range_end = timezone.now()
        range_start = range_end - timezone.timedelta(days=self.default_days_ago)

        params = self.request.query_params
        raw_from = params.get("date_from")
        raw_to = params.get("date_to")
        parse_datetime = DateTimeField().to_internal_value

        if raw_from:
            range_start = parse_datetime(raw_from)
        if raw_to:
            range_end = parse_datetime(raw_to)
        return range_start.timestamp(), range_end.timestamp()


class CommandViewSet(JMSBulkModelViewSet):
    """Accept command logs sent by the app, in the following format

    {
        "user": "admin",
        "asset": "localhost",
        "account": "web",
        "session": "xxxxxx",
        "input": "whoami",
        "output": "d2hvbWFp",  # base64.b64encode(s)
        "timestamp": 1485238673.0
    }

    """
    command_store = get_command_storage()
    serializer_class = SessionCommandSerializer
    filterset_class = CommandFilter
    search_fields = ('input',)
    model = Command
    ordering_fields = ('timestamp',)

    def merge_all_storage_list(self, request, *args, **kwargs):
        """Respond with the filtered commands of every valid storage, merged.

        The merged list is sorted by ``timestamp``: ascending when the
        ``order`` query parameter equals ``'timestamp'``, descending otherwise.
        """
        merged = []
        for storage in CommandStorage.objects.all():
            if storage.is_valid():
                storage_qs = storage.get_command_queryset()
                matched = self.filter_queryset(storage_qs)
                # ES only fetches 10 records by default, hence the slice.
                merged.extend(matched[:])

        order = self.request.query_params.get('order', None)
        descending = order != 'timestamp'
        merged.sort(key=lambda item: item.timestamp, reverse=descending)

        page = self.paginate_queryset(merged)
        if page is None:
            serializer = self.get_serializer(merged, many=True)
            return Response(serializer.data)
        serializer = self.get_serializer(page, many=True)
        return self.get_paginated_response(serializer.data)

    def list(self, request, *args, **kwargs):
        """List commands, each annotated with its session's ``remote_addr``.

        With ``session_id`` but no ``command_storage_id`` the request is
        answered by ``merge_all_storage_list`` instead.
        """
        params = self.request.query_params
        command_storage_id = params.get('command_storage_id')
        session_id = params.get('session_id')

        if session_id and not command_storage_id:
            # The command list for a session always provides session_id;
            # this guards against fetching the full data set when merging.
            return self.merge_all_storage_list(request, *args, **kwargs)

        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is None:
            # Adapt to backends like ES that return only a small amount of
            # data when no pagination is specified.
            commands = self.load_remote_addr(queryset[:])
            serializer = self.get_serializer(commands, many=True)
            return Response(serializer.data)

        page = self.load_remote_addr(page)
        serializer = self.get_serializer(page, many=True)
        return self.get_paginated_response(serializer.data)

    def load_remote_addr(self, queryset):
        """Materialize ``queryset`` and set ``remote_addr`` on every command.

        Addresses are looked up from ``Session`` by each command's ``session``
        value; commands without a matching session get ``''``.
        """
        commands = list(queryset)
        wanted_ids = {item.session for item in commands}
        rows = Session.objects.filter(id__in=wanted_ids).values_list('id', 'remote_addr')
        # Keys are stringified so they can be matched against ``command.session``.
        addr_by_session = {str(sid): remote_addr for sid, remote_addr in rows}
        for item in commands:
            item.remote_addr = addr_by_session.get(item.session, '')
        return commands

    def get_queryset(self):
        """Return the command queryset of the storage named in the request.

        Returns an empty ``Command`` queryset when ``command_storage_id`` is
        not supplied.

        Raises:
            StorageInvalid: the selected storage reports itself as not valid.
        """
        command_storage_id = self.request.query_params.get('command_storage_id')
        if not command_storage_id:
            return Command.objects.none()

        storage = CommandStorage.objects.get(id=command_storage_id)
        if not storage.is_valid():
            raise StorageInvalid
        return storage.get_command_queryset()

    def create(self, request, *args, **kwargs):
        """Validate a batch of commands and save it through ``command_store``.

        Responds 201 on success, 500 when ``bulk_save`` returns a falsy value,
        and 401 when the payload fails validation.
        """
        serializer = self.serializer_class(data=request.data, many=True)
        if not serializer.is_valid():
            msg = "Command not valid: {}".format(serializer.errors)
            logger.error(msg)
            # NOTE(review): 401 for a validation failure looks unusual (400
            # would be conventional); kept as-is since clients may rely on it.
            return Response({"msg": msg}, status=401)

        saved = self.command_store.bulk_save(serializer.validated_data)
        if saved:
            return Response("ok", status=201)
        return Response("Save error", status=500)


class InsecureCommandAlertAPI(generics.CreateAPIView):
    """Receive command alerts and publish a message for the risky ones."""

    serializer_class = InsecureCommandAlertSerializer
    rbac_perms = {
        'POST': 'terminal.add_command'
    }

    def post(self, request, *args, **kwargs):
        """Publish an alert for each command at or above the configured level.

        The payload is validated as a list; invalid data raises the
        serializer's validation error. Always responds with an empty body.
        """
        serializer = InsecureCommandAlertSerializer(data=request.data, many=True)
        serializer.is_valid(raise_exception=True)

        # Lazily pick out the commands whose risk level reaches the threshold.
        risky_commands = (
            item for item in serializer.validated_data
            if item['risk_level'] >= settings.SECURITY_INSECURE_COMMAND_LEVEL
        )
        for item in risky_commands:
            CommandAlertMessage(item).publish_async()
        return Response()