# -*- coding: utf-8 -*-
#
import ipaddress
import json
import logging
import re

from django.apps import apps
from django.core.exceptions import ValidationError
from django.core.validators import MinValueValidator, MaxValueValidator
from django.db import models
from django.db.models import Q, Manager
from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _
from rest_framework.utils.encoders import JSONEncoder

from common.local import add_encrypted_field_set
from common.utils import signer, crypto, contains_ip
from .validators import PortRangeValidator

__all__ = [
    "JsonMixin",
    "JsonDictMixin",
    "JsonListMixin",
    "JsonTypeMixin",
    "JsonCharField",
    "JsonTextField",
    "JsonListCharField",
    "JsonListTextField",
    "JsonDictCharField",
    "JsonDictTextField",
    "EncryptCharField",
    "EncryptTextField",
    "EncryptMixin",
    "EncryptJsonDictTextField",
    "EncryptJsonDictCharField",
    "PortField",
    "PortRangeField",
    "BitChoices",
    "TreeChoices",
    "JSONManyToManyField",
]

logger = logging.getLogger(__name__)


class JsonMixin:
    """Field mixin that stores JSON-serialisable data as JSON text."""

    tp = None

    @staticmethod
    def json_decode(data):
        """Decode JSON text, returning None when ``data`` cannot be decoded."""
        try:
            return json.loads(data)
        except (TypeError, json.JSONDecodeError):
            return None

    @staticmethod
    def json_encode(data):
        """Encode ``data`` to JSON text using rest_framework's JSONEncoder."""
        return json.dumps(data, cls=JSONEncoder)

    def from_db_value(self, value, expression, connection, context=None):
        """Decode the stored JSON text; NULL stays None."""
        if value is None:
            return value
        return self.json_decode(value)

    def to_python(self, value):
        """Decode only str values that start with a double quote.

        Every other value (non-str, or str not starting with ``"``) is
        returned unchanged.
        NOTE(review): JSON object/array text (starting with ``{``/``[``) is
        not decoded here — confirm this is intended.
        """
        if value is None:
            return value

        if not isinstance(value, str) or not value.startswith('"'):
            return value
        return self.json_decode(value)

    def get_prep_value(self, value):
        """Encode ``value`` to JSON text for the database; None stays None."""
        if value is None:
            return value
        return self.json_encode(value)


class JsonTypeMixin(JsonMixin):
    """JsonMixin that coerces any value that is not a ``tp`` into an empty ``tp()``."""

    tp = dict

    def from_db_value(self, value, expression, connection, context=None):
        value = super().from_db_value(value, expression, connection, context)
        if not isinstance(value, self.tp):
            value = self.tp()
        return value

    def to_python(self, value):
        data = super().to_python(value)
        if not isinstance(data, self.tp):
            data = self.tp()
        return data

    def get_prep_value(self, value):
        # Unlike JsonMixin, None is not preserved: it is stored as an empty tp().
        if not isinstance(value, self.tp):
            value = self.tp()
        return self.json_encode(value)


class JsonDictMixin(JsonTypeMixin):
    tp = dict


class JsonDictCharField(JsonDictMixin, models.CharField):
    description = _("Marshal dict data to char field")


class JsonDictTextField(JsonDictMixin, models.TextField):
    description = _("Marshal dict data to text field")


class JsonListMixin(JsonTypeMixin):
    tp = list


class JsonStrListMixin(JsonListMixin):
    pass


class JsonListCharField(JsonListMixin, models.CharField):
    description = _("Marshal list data to char field")


class JsonListTextField(JsonListMixin, models.TextField):
    description = _("Marshal list data to text field")


class JsonCharField(JsonMixin, models.CharField):
    description = _("Marshal data to char field")


class JsonTextField(JsonMixin, models.TextField):
    description = _("Marshal data to text field")


class EncryptMixin:
    """
    Transparently encrypt on save and decrypt on load.

    EncryptMixin must be listed first among the base classes: it decrypts
    before delegating to the next class's ``from_db_value`` (e.g. a Json
    mixin) and encrypts after the next class's ``get_prep_value``.
    """

    def decrypt_from_signer(self, value):
        """Decrypt a value written with the legacy signer; '' on failure."""
        return signer.unsign(value) or ""

    def from_db_value(self, value, expression, connection, context=None):
        if value is None:
            return value

        value = force_text(value)
        plain_value = crypto.decrypt(value)

        # If the current crypto could not decrypt it, fall back to the legacy signer.
        if not plain_value:
            plain_value = self.decrypt_from_signer(value)

        # May be combined with a Json mixin, so decrypt first and then JSON-decode.
        sp = super()
        if hasattr(sp, "from_db_value"):
            plain_value = sp.from_db_value(plain_value, expression, connection, context)
        return plain_value

    def get_prep_value(self, value):
        if value is None:
            return value

        # JSON-encode first (if a later base provides get_prep_value), then encrypt.
        sp = super()
        if hasattr(sp, "get_prep_value"):
            value = sp.get_prep_value(value)
        value = force_text(value)
        # Always store using the current crypto scheme.
        return crypto.encrypt(value)


class EncryptTextField(EncryptMixin, models.TextField):
    description = _("Encrypt field using Secret Key")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        add_encrypted_field_set(self.verbose_name)


class EncryptCharField(EncryptMixin, models.CharField):
    @staticmethod
    def change_max_length(kwargs):
        """Double the requested max_length in place (default 1024, floor 128).

        Presumably this leaves room for ciphertext being longer than the
        plain text — confirm against ``crypto.encrypt``.
        """
        kwargs.setdefault("max_length", 1024)
        max_length = kwargs.get("max_length")
        if max_length < 129:
            max_length = 128
        max_length = max_length * 2
        kwargs["max_length"] = max_length

    def __init__(self, *args, **kwargs):
        self.change_max_length(kwargs)
        super().__init__(*args, **kwargs)
        add_encrypted_field_set(self.verbose_name)

    def deconstruct(self):
        # Undo the doubling done in __init__ so migrations record the
        # caller-facing max_length.
        name, path, args, kwargs = super().deconstruct()
        max_length = kwargs.pop("max_length")
        if max_length > 255:
            max_length = max_length // 2
        kwargs["max_length"] = max_length
        return name, path, args, kwargs


class EncryptJsonDictTextField(EncryptMixin, JsonDictTextField):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        add_encrypted_field_set(self.verbose_name)


class EncryptJsonDictCharField(EncryptMixin, JsonDictCharField):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        add_encrypted_field_set(self.verbose_name)


class PortField(models.IntegerField):
    """Required integer in 0..65535.

    ``blank``, ``null`` and ``validators`` are always overridden, so any
    caller-supplied values for them are discarded.
    """

    def __init__(self, *args, **kwargs):
        kwargs.update(
            {
                "blank": False,
                "null": False,
                "validators": [MinValueValidator(0), MaxValueValidator(65535)],
            }
        )
        super().__init__(*args, **kwargs)


class TreeChoices(models.Choices):
    """Choices that can render themselves as a single-root tree of nodes."""

    @classmethod
    def is_tree(cls):
        return True

    @classmethod
    def branches(cls):
        return list(cls)

    @classmethod
    def tree(cls):
        """Return ``[root]`` where root is labelled "All" and holds every member; [] if not a tree."""
        if not cls.is_tree():
            return []
        root = [_("All"), cls.branches()]
        return [cls.render_node(root)]

    @classmethod
    def render_node(cls, node):
        """Render a Choices member as a leaf, or a ``[name, children]`` pair as a branch."""
        if isinstance(node, models.Choices):
            return {
                "value": node.name,
                "label": node.label,
            }
        name, children = node
        return {
            "value": name,
            "label": name,
            "children": [cls.render_node(child) for child in children],
        }

    @classmethod
    def all(cls):
        return [i[0] for i in cls.choices]


class BitChoices(models.IntegerChoices, TreeChoices):
    """Integer choices meant to be combined as bit flags."""

    @classmethod
    def is_tree(cls):
        return False

    @classmethod
    def all(cls):
        """Return the bitwise OR of every member's value."""
        value = 0
        for c in cls:
            value |= c.value
        return value


class PortRangeField(models.CharField):
    """CharField (max_length forced to 16) validated by PortRangeValidator."""

    def __init__(self, *args, **kwargs):
        # *args lets callers pass verbose_name positionally, as with CharField.
        kwargs['max_length'] = 16
        super().__init__(*args, **kwargs)
        self.validators.append(PortRangeValidator())


class RelatedManager:
    """Manager-like accessor that resolves a JSONManyToManyField rule to a queryset."""

    def __init__(self, instance, field):
        self.instance = instance
        self.field = field
        self.value = None

    def set(self, value):
        """Store the raw rule and mirror it into ``instance.__dict__`` under the field name.

        NOTE(review): the ``__dict__`` write presumably lets Django treat the
        field as loaded (not deferred) — confirm.
        """
        self.value = value
        self.instance.__dict__[self.field.name] = value

    @classmethod
    def get_filter_q(cls, value, to_model):
        """Translate a rule dict into a Q over ``to_model``.

        Empty, non-dict, unknown-type or malformed rules yield ``Q()``, i.e.
        no restriction.
        """
        if not value or not isinstance(value, dict):
            return Q()

        # .get(): a dict without "type" previously raised KeyError.
        tp = value.get("type")
        if tp == "all":
            return Q()
        elif tp == "ids" and isinstance(value.get("ids"), list):
            return Q(id__in=value["ids"])
        elif tp == "attrs" and isinstance(value.get("attrs"), list):
            return cls._get_filter_attrs_q(value, to_model)
        return Q()

    @classmethod
    def filter_queryset_by_model(cls, value, to_model):
        if hasattr(to_model, "get_queryset"):
            queryset = to_model.get_queryset()
        else:
            queryset = to_model.objects.all()
        q = cls.get_filter_q(value, to_model)
        return queryset.filter(q).distinct()

    @staticmethod
    def get_ip_in_q(name, val):
        """Build a Q matching field ``name`` against one or more IP rules.

        ``val`` is a str or a list of str; each entry may be ``*`` (match
        everything), a network (``10.0.0.0/24``), a range
        (``10.0.0.1-10.0.0.9``), a four-part dotted address (exact match) or
        anything else (prefix match). Entries that fail to parse are skipped.
        """
        q = Q()
        if isinstance(val, str):
            val = [val]
        for ip in val:
            if not ip:
                continue
            try:
                if ip == '*':
                    return Q()
                elif '/' in ip:
                    network = ipaddress.ip_network(ip)
                    # Materialise as strings: a generator of address objects
                    # is consumed by the first evaluation of the Q, and the
                    # field receives plain strings like the other branches.
                    ips = [str(host) for host in network.hosts()]
                    q |= Q(**{"{}__in".format(name): ips})
                elif '-' in ip:
                    start_ip, end_ip = ip.split('-')
                    start_ip = ipaddress.ip_address(start_ip)
                    end_ip = ipaddress.ip_address(end_ip)
                    # NOTE(review): on a character column this compares strings,
                    # not numeric addresses ("10.0.0.100" sorts between
                    # "10.0.0.1" and "10.0.0.20") — confirm the column type.
                    q |= Q(**{"{}__range".format(name): (str(start_ip), str(end_ip))})
                elif len(ip.split('.')) == 4:
                    q |= Q(**{"{}__exact".format(name): ip})
                else:
                    q |= Q(**{"{}__startswith".format(name): ip})
            except ValueError:
                continue
        return q

    @classmethod
    def _get_filter_attrs_q(cls, value, to_model):
        """AND together one Q per attribute rule in ``value["attrs"]``."""
        filters = Q()
        # Special cases that cannot be expressed as a plain lookup:
        # 1. Asset ``type`` and ``category`` are inherited from Platform, so
        #    they cannot be queried directly.
        # 2. Asset ``nodes`` is a tree relation, not a simple m2m.
        # 3. User ``orgs`` is computed, not a simple m2m either.
        # The to-model's ``get_json_filter_attr_q`` hook handles these.
        custom_attr_filter = getattr(to_model, "get_json_filter_attr_q", None)
        for attr in value["attrs"]:
            if not isinstance(attr, dict):
                continue

            name = attr.get('name')
            val = attr.get('value')
            match = attr.get('match', 'exact')
            if name is None or val is None:
                continue

            if custom_attr_filter:
                custom_filter_q = custom_attr_filter(name, val, match)
                if custom_filter_q:
                    filters &= custom_filter_q
                    continue

            if match == 'ip_in':
                q = cls.get_ip_in_q(name, val)
            elif match in ("exact", "contains", "startswith", "endswith", "regex", "gte", "lte", "gt", "lt"):
                lookup = "{}__{}".format(name, match)
                q = Q(**{lookup: val})
            elif match == "not":
                q = ~Q(**{name: val})
            elif match in ['m2m', 'in']:
                if not isinstance(val, list):
                    val = [val]
                q = Q() if '*' in val else Q(**{"{}__in".format(name): val})
            else:
                q = Q() if val == '*' else Q(**{name: val})

            filters &= q
        return filters

    def _get_queryset(self):
        to_model = apps.get_model(self.field.to)
        return self.filter_queryset_by_model(self.value, to_model)

    def get_attr_q(self):
        return self._get_filter_attrs_q(self.value, apps.get_model(self.field.to))

    def all(self):
        return self._get_queryset()

    def filter(self, *args, **kwargs):
        return self._get_queryset().filter(*args, **kwargs)


class JSONManyToManyDescriptor:
    """Descriptor exposing a JSONManyToManyField as a cached RelatedManager per instance."""

    def __init__(self, field):
        self.field = field
        self._is_setting = False

    def __get__(self, instance, owner=None):
        if instance is None:
            return self

        if not hasattr(instance, "_related_manager_cache"):
            instance._related_manager_cache = {}
        if self.field.name not in instance._related_manager_cache:
            manager = RelatedManager(instance, self.field)
            instance._related_manager_cache[self.field.name] = manager
        return instance._related_manager_cache[self.field.name]

    def __set__(self, instance, value):
        if instance is None:
            return
        # __get__ creates and caches the manager on first access.
        manager = self.__get__(instance, type(instance))
        if isinstance(value, RelatedManager):
            value = value.value
        manager.set(value)

    def is_match(self, obj, attr_rules):
        """Return True when ``obj`` (a ``to``-model instance) satisfies every rule.

        This is the in-memory counterpart of
        ``RelatedManager._get_filter_attrs_q`` and follows the same rule
        semantics. Rules claimed by a custom filter hook (e.g. nodes,
        category) are collected into one Q and checked with one query.
        """
        res = True
        to_model = apps.get_model(self.field.to)
        src_model = self.field.model
        # Use the same hook as the queryset path; fall back to the legacy
        # per-field hook on the source model.
        custom_attr_filter = getattr(to_model, "get_json_filter_attr_q", None) or getattr(
            src_model, "get_filter_{}_attr_q".format(self.field.name), None
        )
        custom_q = Q()
        for rule in attr_rules:
            value = getattr(obj, rule['name'], '')
            rule_value = rule.get('value', '')
            rule_match = rule.get('match', 'exact')

            if custom_attr_filter:
                q = custom_attr_filter(rule['name'], rule_value, rule_match)
                if q:
                    custom_q &= q
                    continue

            if rule_match in ('in', 'm2m'):
                # Mirror the queryset path: a scalar is a one-element list
                # and '*' matches everything.
                if not isinstance(rule_value, list):
                    rule_value = [rule_value]
                if '*' in rule_value:
                    continue

            if rule_match == 'in':
                res &= value in rule_value
            elif rule_match == 'exact':
                res &= value == rule_value
            elif rule_match == 'contains':
                res &= rule_value in value
            elif rule_match == 'startswith':
                res &= str(value).startswith(str(rule_value))
            elif rule_match == 'endswith':
                res &= str(value).endswith(str(rule_value))
            elif rule_match == 'regex':
                # `bool & Match` raises TypeError, so reduce to a bool first.
                # re.search (unanchored) mirrors the `__regex` lookup.
                try:
                    matched = re.search(str(rule_value), str(value)) is not None
                except re.error:
                    matched = False
                res &= matched
            elif rule_match == 'not':
                res &= value != rule_value
            elif rule_match == 'gte':
                res &= value >= rule_value
            elif rule_match == 'lte':
                res &= value <= rule_value
            elif rule_match == 'gt':
                res &= value > rule_value
            elif rule_match == 'lt':
                res &= value < rule_value
            elif rule_match == 'ip_in':
                if isinstance(rule_value, str):
                    rule_value = [rule_value]
                res &= contains_ip(value, rule_value)
            elif rule_match == 'm2m':
                if isinstance(value, Manager):
                    value = value.values_list('id', flat=True)
                value = set(map(str, value))
                rule_value = set(map(str, rule_value))
                # Any overlap matches, like `<name>__in` in the queryset path.
                res &= bool(value & rule_value)
            else:
                logger.error("unknown match: %s", rule_match)
                res &= False

            if not res:
                return res
        if custom_q:
            res &= to_model.objects.filter(custom_q).filter(id=obj.id).exists()
        return res

    def get_filter_q(self, instance):
        """Return a Q over the field's model selecting rows whose rule includes ``instance``.

        A row matches when its rule is type 'all', type 'ids' listing
        ``instance.id``, or type 'attrs' whose rules pass ``is_match``.
        """
        model_cls = self.field.model
        # ORM lookups take the field name; the db column differs when
        # db_column is set.
        field_name = self.field.name
        q = Q(**{f'{field_name}__type': 'all'}) | \
            Q(**{f'{field_name}__type': 'ids', f'{field_name}__ids__contains': [str(instance.id)]})

        queryset_id_attrs = model_cls.objects \
            .filter(**{'{}__type'.format(field_name): 'attrs'}) \
            .values_list('id', '{}__attrs'.format(field_name))
        ids = [str(_id) for _id, attr_rules in queryset_id_attrs if self.is_match(instance, attr_rules)]
        if ids:
            q |= Q(id__in=ids)
        return q


class JSONManyToManyField(models.JSONField):
    """JSONField holding a rule that selects a set of ``to`` objects.

    Accepted shapes: ``{'type': 'all'}``, ``{'type': 'ids', 'ids': [...]}``
    or ``{'type': 'attrs', 'attrs': [{'name', 'match', 'value'}, ...]}``.
    On an instance the attribute is a RelatedManager.
    """

    def __init__(self, to, *args, **kwargs):
        self.to = to
        super().__init__(*args, **kwargs)

    def contribute_to_class(self, cls, name, **kwargs):
        super().contribute_to_class(cls, name, **kwargs)
        setattr(cls, self.name, JSONManyToManyDescriptor(self))

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        kwargs['to'] = self.to
        return name, path, args, kwargs

    @staticmethod
    def check_value(val):
        """Raise ValueError unless ``val`` is empty or a well-formed rule dict."""
        if not val:
            return val
        e = ValueError(_(
            "Invalid JSON data for JSONManyToManyField, should be like "
            "{'type': 'all'} or {'type': 'ids', 'ids': []} "
            "or {'type': 'attrs', 'attrs': [{'name': 'ip', 'match': 'exact', 'value': '1.1.1.1'}}"
        ))
        if not isinstance(val, dict):
            raise e
        # .get(): missing keys must surface as ValueError, not KeyError.
        if val.get("type") not in ["all", "ids", "attrs"]:
            raise ValueError(_('Invalid type, should be "all", "ids" or "attrs"'))
        if val["type"] == "ids":
            if not isinstance(val.get("ids"), list):
                raise ValueError(_("Invalid ids for ids, should be a list"))
        elif val["type"] == "attrs":
            if not isinstance(val.get("attrs"), list):
                raise ValueError(_("Invalid attrs, should be a list of dict"))
            for attr in val["attrs"]:
                if not isinstance(attr, dict):
                    raise ValueError(_("Invalid attrs, should be a list of dict"))
                if 'name' not in attr or 'value' not in attr:
                    raise ValueError(_("Invalid attrs, should be has name and value"))

    def get_prep_value(self, value):
        if value is None:
            return None
        if isinstance(value, RelatedManager):
            value = value.value
        # Honour the field's configured encoder, as JSONField does.
        return json.dumps(value, cls=self.encoder)

    def validate(self, value, model_instance):
        super().validate(value, model_instance)
        if not isinstance(value, dict):
            raise ValidationError("Invalid JSON data for JSONManyToManyField.")
        self.check_value(value)