fix: 解决 swagger api 报错的问题

pull/9616/head^2
Bai 2023-02-17 21:09:20 +08:00 committed by Jiangjie.Bai
parent 2b29ce69f4
commit 1035e27201
3 changed files with 24 additions and 12 deletions

View File

@ -312,7 +312,7 @@ class DatesLoginMetricMixin:
@lazyproperty
def commands_danger_amount(self):
return self.command_queryset.filter(risk_level=Command.RISK_LEVEL_DANGEROUS).count()
return self.command_queryset.filter(risk_level=Command.RiskLevelChoices.dangerous).count()
@lazyproperty
def job_logs_running_amount(self):

View File

@ -10,12 +10,10 @@ from orgs.mixins.models import OrgModelMixin
class AbstractSessionCommand(OrgModelMixin):
RISK_LEVEL_ORDINARY = 0
RISK_LEVEL_DANGEROUS = 5
RISK_LEVEL_CHOICES = (
(RISK_LEVEL_ORDINARY, _('Ordinary')),
(RISK_LEVEL_DANGEROUS, _('Dangerous')),
)
class RiskLevelChoices(models.IntegerChoices):
ordinary = 0, _('Ordinary')
dangerous = 5, _('Dangerous')
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
user = models.CharField(max_length=64, db_index=True, verbose_name=_("User"))
asset = models.CharField(max_length=128, db_index=True, verbose_name=_("Asset"))
@ -23,7 +21,10 @@ class AbstractSessionCommand(OrgModelMixin):
input = models.CharField(max_length=128, db_index=True, verbose_name=_("Input"))
output = models.CharField(max_length=1024, blank=True, verbose_name=_("Output"))
session = models.CharField(max_length=36, db_index=True, verbose_name=_("Session"))
risk_level = models.SmallIntegerField(default=RISK_LEVEL_ORDINARY, choices=RISK_LEVEL_CHOICES, db_index=True, verbose_name=_("Risk level"))
risk_level = models.SmallIntegerField(
default=RiskLevelChoices.ordinary, choices=RiskLevelChoices.choices, db_index=True,
verbose_name=_("Risk level")
)
timestamp = models.IntegerField(db_index=True)
class Meta:
@ -44,7 +45,7 @@ class AbstractSessionCommand(OrgModelMixin):
@classmethod
def get_risk_level_str(cls, risk_level):
risk_mapper = dict(cls.RISK_LEVEL_CHOICES)
risk_mapper = dict(cls.RiskLevelChoices.choices)
return risk_mapper.get(risk_level)
def to_dict(self):

View File

@ -5,11 +5,12 @@ from rest_framework import serializers
from common.utils import pretty_string
from common.serializers.fields import LabeledChoiceField
from terminal.backends.command.models import AbstractSessionCommand
from terminal.models import Command
__all__ = ['SessionCommandSerializer', 'InsecureCommandAlertSerializer']
class SimpleSessionCommandSerializer(serializers.Serializer):
class SimpleSessionCommandSerializer(serializers.ModelSerializer):
""" 简单Session命令序列类, 用来提取公共字段 """
user = serializers.CharField(label=_("User")) # 限制 64 字符,见 validate_user
@ -17,12 +18,18 @@ class SimpleSessionCommandSerializer(serializers.Serializer):
input = serializers.CharField(max_length=2048, label=_("Command"))
session = serializers.CharField(max_length=36, label=_("Session ID"))
risk_level = LabeledChoiceField(
required=False, label=_("Risk level"), choices=AbstractSessionCommand.RISK_LEVEL_CHOICES
choices=AbstractSessionCommand.RiskLevelChoices.choices,
required=False, label=_("Risk level"),
)
org_id = serializers.CharField(
max_length=36, required=False, default='', allow_null=True, allow_blank=True
)
class Meta:
# 继承 ModelSerializer 解决 swagger risk_level type 为 object 的问题
model = Command
fields = ['user', 'asset', 'input', 'session', 'risk_level', 'org_id']
def validate_user(self, value):
if len(value) > 64:
value = value[:32] + value[-32:]
@ -51,5 +58,9 @@ class SessionCommandSerializerMixin(serializers.Serializer):
class SessionCommandSerializer(SessionCommandSerializerMixin, SimpleSessionCommandSerializer):
""" 字段排序序列类 """
pass
class Meta(SimpleSessionCommandSerializer.Meta):
fields = SimpleSessionCommandSerializer.Meta.fields + [
'id', 'account', 'output', 'timestamp', 'timestamp_display', 'remote_addr'
]