# -*- coding: utf-8 -*- # from django.utils import timezone from rest_framework import generics from rest_framework.fields import DateTimeField from rest_framework.response import Response from acls.models import CommandFilterACL, CommandGroup from common.api import JMSBulkModelViewSet from common.utils import get_logger from orgs.utils import current_org from terminal.backends import ( get_command_storage, get_multi_command_storage ) from terminal.const import RiskLevelChoices from terminal.exceptions import StorageInvalid from terminal.filters import CommandFilter from terminal.models import CommandStorage, Session, Command from terminal.notifications import CommandAlertMessage, CommandWarningMessage from terminal.serializers import ( SessionCommandSerializer, InsecureCommandAlertSerializer ) logger = get_logger(__name__) __all__ = ['CommandViewSet', 'InsecureCommandAlertAPI'] class CommandQueryMixin: command_store = get_command_storage() filterset_fields = [ "asset", "account", "user", "session", "risk_level", "input" ] default_days_ago = 5 @staticmethod def get_org_id(): if current_org.is_default(): org_id = '' else: org_id = current_org.id return org_id def get_query_risk_level(self): risk_level = self.request.query_params.get('risk_level') if risk_level is None: return None if risk_level.isdigit(): return int(risk_level) return None def get_queryset(self): # 解决访问 /docs/ 问题 if hasattr(self, 'swagger_fake_view'): return self.command_store.model.objects.none() date_from, date_to = self.get_date_range() q = self.request.query_params multi_command_storage = get_multi_command_storage() queryset = multi_command_storage.filter( date_from=date_from, date_to=date_to, user=q.get("user"), asset=q.get("asset"), account=q.get("account"), input=q.get("input"), session=q.get("session_id", q.get('session')), risk_level=self.get_query_risk_level(), org_id=self.get_org_id(), ) return queryset def filter_queryset(self, queryset): # 解决es存储命令时,父类根据filter_fields过滤出现异常的问题,返回的queryset类型list return queryset def get_date_range(self): now = timezone.now() days_ago = now - timezone.timedelta(days=self.default_days_ago) date_from_st = days_ago.timestamp() date_to_st = now.timestamp() query_params = self.request.query_params date_from_q = query_params.get("date_from") date_to_q = query_params.get("date_to") dt_parser = DateTimeField().to_internal_value if date_from_q: date_from_st = dt_parser(date_from_q).timestamp() if date_to_q: date_to_st = dt_parser(date_to_q).timestamp() return date_from_st, date_to_st class CommandViewSet(JMSBulkModelViewSet): """接受app发送来的command log, 格式如下 { "user": "admin", "asset": "localhost", "account": "web", "session": "xxxxxx", "input": "whoami", "output": "d2hvbWFp", # base64.b64encode(s) "timestamp": 1485238673.0 } """ command_store = get_command_storage() serializer_class = SessionCommandSerializer filterset_class = CommandFilter model = Command search_fields = ('input',) ordering_fields = ('timestamp', 'risk_level') def merge_all_storage_list(self, request, *args, **kwargs): merged_commands = [] storages = CommandStorage.objects.all() for storage in storages: if not storage.is_valid(): continue qs = storage.get_command_queryset() commands = self.filter_queryset(qs) merged_commands.extend(commands[:]) # ES 默认只取 10 条数据 order = self.request.query_params.get('order', None) if order == 'timestamp': merged_commands.sort(key=lambda command: command.timestamp) else: merged_commands.sort(key=lambda command: command.timestamp, reverse=True) page = self.paginate_queryset(merged_commands) if page is not None: serializer = self.get_serializer(page, many=True) return self.get_paginated_response(serializer.data) serializer = self.get_serializer(merged_commands, many=True) return Response(serializer.data) def list(self, request, *args, **kwargs): command_storage_id = self.request.query_params.get('command_storage_id') session_id = self.request.query_params.get('session_id') if session_id and not command_storage_id: # 会话里的命令列表肯定会提供 session_id,这里防止 merge 的时候取全量的数据 return self.merge_all_storage_list(request, *args, **kwargs) queryset = self.get_queryset() queryset = self.filter_queryset(queryset) page = self.paginate_queryset(queryset) if page is not None: page = self.load_remote_addr(page) serializer = self.get_serializer(page, many=True) return self.get_paginated_response(serializer.data) # 适配像 ES 这种没有指定分页只返回少量数据的情况 queryset = queryset[:] queryset = self.load_remote_addr(queryset) serializer = self.get_serializer(queryset, many=True) return Response(serializer.data) def load_remote_addr(self, queryset): commands = list(queryset) session_ids = {command.session for command in commands} sessions = Session.objects.filter(id__in=session_ids).values_list('id', 'remote_addr') session_addr_map = {str(i): addr for i, addr in sessions} for command in commands: command.remote_addr = session_addr_map.get(command.session, '') return commands def get_queryset(self): command_storage_id = self.request.query_params.get('command_storage_id') if not command_storage_id: return Command.objects.none() storage = CommandStorage.objects.get(id=command_storage_id) if not storage.is_valid(): raise StorageInvalid else: qs = storage.get_command_queryset() return qs def create(self, request, *args, **kwargs): serializer = self.serializer_class(data=request.data, many=True) if serializer.is_valid(): ok = self.command_store.bulk_save(serializer.validated_data) if ok: return Response("ok", status=201) else: return Response("Save error", status=500) else: msg = "Command not valid: {}".format(serializer.errors) logger.error(msg) return Response({"msg": msg}, status=401) class InsecureCommandAlertAPI(generics.CreateAPIView): serializer_class = InsecureCommandAlertSerializer rbac_perms = { 'POST': 'terminal.add_command' } def post(self, request, *args, **kwargs): serializer = InsecureCommandAlertSerializer(data=request.data, many=True) serializer.is_valid(raise_exception=True) commands = serializer.validated_data session_ids, acl_ids, cmd_group_ids = set(), set(), set() for command in commands: session_ids.add(command.get('session')) acl_ids.add(command.get('cmd_filter_acl')) cmd_group_ids.add(command.get('cmd_group')) sessions = Session.objects.filter(id__in=session_ids).only( 'id', 'org_id', 'asset', 'asset_id', 'user', 'user_id', 'account', 'account_id' ) session_mapper = {str(i.id): i for i in sessions} acls = CommandFilterACL.objects.filter(id__in=acl_ids).only('id', 'name', 'reviewers') acl_mapper = {str(i.id): i for i in acls} cmd_groups = CommandGroup.objects.filter(id__in=cmd_group_ids).only('id', 'name') cmd_group_mapper = {str(i.id): i for i in cmd_groups} for command in commands: cmd_acl = acl_mapper.get(command['cmd_filter_acl']) command['_cmd_filter_acl'] = cmd_acl cmd_group = cmd_group_mapper.get(command['cmd_group']) command['_cmd_group'] = cmd_group session = session_mapper.get(command['session']) risk_level = command.get('risk_level') if session: command.update({ '_account': session.account, '_org_name': session.org.name, }) if risk_level in [RiskLevelChoices.reject, RiskLevelChoices.review_reject]: CommandAlertMessage(command).publish_async() elif risk_level in [RiskLevelChoices.warning]: for reviewer in cmd_acl.reviewers.all(): CommandWarningMessage(reviewer, command).publish_async() else: logger.info(f'Risk level ignore: {RiskLevelChoices.get_label(risk_level)}({risk_level})') return Response({'msg': 'ok'})