mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			204 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			204 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
# -*- coding: utf-8 -*-
 | 
						|
#
 | 
						|
from django.core.exceptions import ObjectDoesNotExist
 | 
						|
from django.utils.translation import gettext_lazy as _
 | 
						|
from rest_framework import serializers
 | 
						|
from rest_framework.fields import ChoiceField, empty
 | 
						|
 | 
						|
from common.db.fields import TreeChoices
 | 
						|
from common.local import add_encrypted_field_set
 | 
						|
from common.utils import decrypt_password
 | 
						|
 | 
						|
# Public API of this module: the custom serializer fields defined below.
__all__ = [
    "ReadableHiddenField", "EncryptedField", "LabeledChoiceField",
    "ObjectRelatedField", "BitChoicesField", "TreeChoicesField",
    "LabeledMultipleChoiceField",
]
 | 
						|
 | 
						|
 | 
						|
# ReadableHiddenField
 | 
						|
# -------------------
 | 
						|
 | 
						|
 | 
						|
class ReadableHiddenField(serializers.HiddenField):
    """Readable HiddenField.

    After the parent initializer runs, ``write_only`` is switched back off so
    the field also appears in serialized output. A value that has an ``id``
    attribute is rendered as that id; any other value is returned unchanged.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Turn the field back into a readable one.
        self.write_only = False

    def to_representation(self, value):
        # Falls back to the value itself when it has no ``id`` attribute.
        return getattr(value, "id", value)
 | 
						|
 | 
						|
 | 
						|
class EncryptedField(serializers.CharField):
    """CharField whose validated input is passed through ``decrypt_password``.

    The field defaults to write-only unless ``write_only`` is passed
    explicitly. At construction time the optional ``encrypted_key`` kwarg
    (or, when absent, the field's label) is handed to
    ``add_encrypted_field_set``.
    """

    def __init__(self, write_only=None, **kwargs):
        # ``encrypted_key`` is this class's own option; keep it away from CharField.
        encrypted_key = kwargs.pop('encrypted_key', None)
        kwargs["write_only"] = True if write_only is None else write_only
        super().__init__(**kwargs)
        add_encrypted_field_set(encrypted_key or self.label)

    def to_internal_value(self, value):
        # Run the normal CharField validation first, then decrypt its result.
        return decrypt_password(super().to_internal_value(value))
 | 
						|
 | 
						|
 | 
						|
class LabeledChoiceField(ChoiceField):
    """ChoiceField rendered as ``{"value": key, "label": label}``.

    Input may be the raw choice key or a dict carrying the key under
    ``"value"`` (i.e. the field's own output); validation is delegated to
    ``ChoiceField``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Choice key -> label, captured once at construction time.
        self.choice_mapper = dict(self.choices.items())

    def to_representation(self, key):
        if key is None:
            return None
        return {"value": key, "label": self.choice_mapper.get(key)}

    def to_internal_value(self, data):
        raw_key = data.get("value") if isinstance(data, dict) else data
        return super().to_internal_value(raw_key)
 | 
						|
 | 
						|
 | 
						|
class LabeledMultipleChoiceField(serializers.MultipleChoiceField):
    """MultipleChoiceField rendered as a list of ``{"value", "label"}`` dicts.

    Input may be a list of raw choice keys, a list of ``{"value": key}``
    dicts (the field's own output format), or a mix of both. Every key is
    validated with ``ChoiceField``'s single-value check, matching what
    ``LabeledChoiceField`` does. The result is returned as a list in input
    order; ``MultipleChoiceField`` itself would return a set.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Choice key -> label, captured once at construction time.
        self.choice_mapper = {
            key: value for key, value in self.choices.items()
        }

    def to_representation(self, keys):
        if keys is None:
            return keys
        return [
            {"value": key, "label": self.choice_mapper.get(key)}
            for key in keys
        ]

    def to_internal_value(self, data):
        if not data:
            # Honour MultipleChoiceField's ``allow_empty=False`` option, which
            # an unconditional early return would silently bypass.
            if not self.allow_empty:
                self.fail("empty")
            return data
        # A string or a single dict is not a list of choices. Indexing
        # ``data[0]`` would accept a string as-is and raise KeyError on a dict.
        if isinstance(data, (str, dict)) or not hasattr(data, "__iter__"):
            self.fail("not_a_list", input_type=type(data).__name__)
        # Unwrap dict items one by one so mixed lists are handled too.
        keys = [
            item.get("value") if isinstance(item, dict) else item
            for item in data
        ]
        # Validate each key against ``choices``; raises "invalid_choice".
        return [
            serializers.ChoiceField.to_internal_value(self, key)
            for key in keys
        ]
 | 
						|
 | 
						|
 | 
						|
class ObjectRelatedField(serializers.RelatedField):
    """Related field that renders an object as a dict of selected attributes.

    Output: ``{attr: getattr(obj, attr)}`` for each name in ``attrs``
    (default ``("id", "name")``), skipping attributes the object lacks.
    Input: either a bare primary key, or a dict carrying it under ``"id"``,
    ``"pk"`` or ``attrs[0]`` (first truthy one wins). The object is fetched
    with ``queryset.get(pk=...)``.
    """

    default_error_messages = {
        "required": _("This field is required."),
        "does_not_exist": _('Invalid pk "{pk_value}" - object does not exist.'),
        "incorrect_type": _("Incorrect type. Expected pk value, received {data_type}."),
    }

    def __init__(self, **kwargs):
        self.attrs = kwargs.pop("attrs", None) or ("id", "name")
        self.many = kwargs.get("many", False)
        super().__init__(**kwargs)

    def to_representation(self, value):
        return {
            attr: getattr(value, attr)
            for attr in self.attrs
            if hasattr(value, attr)
        }

    def to_internal_value(self, data):
        if isinstance(data, dict):
            pk = data.get("id") or data.get("pk") or data.get(self.attrs[0])
        else:
            pk = data
        queryset = self.get_queryset()
        try:
            # Reject booleans on the extracted pk, not the raw payload.
            # Otherwise ``{"id": True}`` reaches ``get(pk=True)``, which an
            # integer primary key coerces to 1.
            if isinstance(pk, bool):
                raise TypeError
            return queryset.get(pk=pk)
        except ObjectDoesNotExist:
            self.fail("does_not_exist", pk_value=pk)
        except (TypeError, ValueError):
            self.fail("incorrect_type", data_type=type(pk).__name__)
 | 
						|
 | 
						|
 | 
						|
class TreeChoicesField(serializers.MultipleChoiceField):
    """MultipleChoiceField built from a ``TreeChoices`` class.

    ``choices`` are the ``(name, label)`` pairs of ``choice_cls``, and
    ``self.tree`` holds ``choice_cls.tree()``. Input may be a list of names,
    a list of ``{"value": name, ...}`` dicts, or a mix. Values are NOT checked
    against ``choices`` here.
    """

    def __init__(self, choice_cls, **kwargs):
        assert issubclass(choice_cls, TreeChoices)
        choices = [(c.name, c.label) for c in choice_cls]
        self.tree = choice_cls.tree()
        self._choice_cls = choice_cls
        super().__init__(choices=choices, **kwargs)

    def to_internal_value(self, data):
        if not data:
            return data
        # A string or a single dict is not a list of choices. Indexing
        # ``data[0]`` would pass a string through and raise KeyError on a dict.
        if isinstance(data, (str, dict)) or not hasattr(data, "__iter__"):
            self.fail("not_a_list", input_type=type(data).__name__)
        # Unwrap dict items one by one so mixed lists are handled too.
        return [
            item.get("value") if isinstance(item, dict) else item
            for item in data
        ]
 | 
						|
 | 
						|
 | 
						|
class BitChoicesField(TreeChoicesField):
    """Bit field: a list of choice names stored as one integer.

    Output: ``{"value": name, "label": label}`` for every member of the
    choice class whose ``value`` bits are all set in the stored integer.
    Input: a list of names, ``{"value": name}`` dicts, or a mix. They are
    OR-ed together; the name ``"all"`` selects every member.
    """

    def to_representation(self, value):
        if isinstance(value, list) and len(value) == 1:
            # Swagger passes values in while iterating field.choices.keys();
            # render only the matching choice.
            return [
                {"value": c.name, "label": c.label}
                for c in self._choice_cls
                if c.name == value[0]
            ]
        return [
            {"value": c.name, "label": c.label}
            for c in self._choice_cls
            if c.value & value == c.value
        ]

    def to_internal_value(self, data):
        if not isinstance(data, list):
            raise serializers.ValidationError(_("Invalid data type, should be list"))
        # Unwrap dict items one by one: mixed lists work, and a dict without
        # "value" becomes an invalid choice instead of a KeyError.
        names = [d.get("value") if isinstance(d, dict) else d for d in data]
        value = 0

        # Everything
        if "all" in names:
            for c in self._choice_cls:
                value |= c.value
            return value

        name_value_map = {c.name: c.value for c in self._choice_cls}
        for name in names:
            try:
                value |= name_value_map[name]
            except (KeyError, TypeError):
                # TypeError covers unhashable items such as nested lists/dicts.
                raise serializers.ValidationError(
                    _("Invalid choice: {}").format(name)
                ) from None
        return value

    def run_validation(self, data=empty):
        """
        Note:
        When an asset permission is created without an ``actions`` field,
        the default value (set in AssetPermission) is used, and it would be
        saved to the database as-is (e.g. ['connect', '...']), causing a
        type error. So the value obtained here is passed through
        to_internal_value once more to convert it to the internal value.
        """
        data = super().run_validation(data)
        if isinstance(data, int):
            return data
        value = self.to_internal_value(data)
        self.run_validators(value)
        return value
 |